Add InetDiagMessage.destroyLiveTcpSockets am: a22e2ed0fa am: d854f708bf
Original change: https://googleplex-android-review.googlesource.com/c/platform/frameworks/libs/net/+/23126578
Change-Id: I7dbe6f973de7d29677841f3ef261c70e6abd92a1
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/common/device/com/android/net/module/util/netlink/InetDiagMessage.java b/common/device/com/android/net/module/util/netlink/InetDiagMessage.java
index 0a2f50d..d462c53 100644
--- a/common/device/com/android/net/module/util/netlink/InetDiagMessage.java
+++ b/common/device/com/android/net/module/util/netlink/InetDiagMessage.java
@@ -19,32 +19,48 @@
import static android.os.Process.INVALID_UID;
import static android.system.OsConstants.AF_INET;
import static android.system.OsConstants.AF_INET6;
+import static android.system.OsConstants.ENOENT;
import static android.system.OsConstants.IPPROTO_TCP;
import static android.system.OsConstants.IPPROTO_UDP;
import static android.system.OsConstants.NETLINK_INET_DIAG;
+import static com.android.net.module.util.netlink.NetlinkConstants.NLMSG_DONE;
+import static com.android.net.module.util.netlink.NetlinkConstants.SOCK_DESTROY;
import static com.android.net.module.util.netlink.NetlinkConstants.SOCK_DIAG_BY_FAMILY;
+import static com.android.net.module.util.netlink.NetlinkConstants.hexify;
+import static com.android.net.module.util.netlink.NetlinkConstants.stringForAddressFamily;
+import static com.android.net.module.util.netlink.NetlinkConstants.stringForProtocol;
import static com.android.net.module.util.netlink.NetlinkUtils.DEFAULT_RECV_BUFSIZE;
+import static com.android.net.module.util.netlink.NetlinkUtils.IO_TIMEOUT_MS;
+import static com.android.net.module.util.netlink.NetlinkUtils.TCP_ALIVE_STATE_FILTER;
+import static com.android.net.module.util.netlink.NetlinkUtils.connectSocketToNetlink;
import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_DUMP;
import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST;
import android.net.util.SocketUtils;
+import android.os.Process;
import android.system.ErrnoException;
import android.util.Log;
+import android.util.Range;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
+import androidx.annotation.VisibleForTesting;
import java.io.FileDescriptor;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.Inet4Address;
import java.net.Inet6Address;
+import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Predicate;
/**
* A NetlinkMessage subclass for netlink inet_diag messages.
@@ -138,7 +154,8 @@
public StructInetDiagMsg inetDiagMsg;
- private InetDiagMessage(@NonNull StructNlMsgHdr header) {
+ /**
+ * Creates a message wrapping {@code header} with a newly constructed StructInetDiagMsg.
+ * Visibility is widened from private to public so that tests can build messages directly.
+ */
+ @VisibleForTesting
+ public InetDiagMessage(@NonNull StructNlMsgHdr header) {
super(header);
inetDiagMsg = new StructInetDiagMsg();
}
@@ -157,6 +174,13 @@
return msg;
}
+    /**
+     * Closes {@code fd}, ignoring any {@link IOException} raised while closing.
+     *
+     * <p>A {@code null} descriptor is a no-op. Callers such as destroySockets() start with
+     * null descriptors and may reach their finally block before any socket was created.
+     *
+     * @param fd the socket to close; may be {@code null}
+     */
+    private static void closeSocketQuietly(@Nullable final FileDescriptor fd) {
+        if (fd == null) return;
+        try {
+            SocketUtils.closeSocket(fd);
+        } catch (IOException ignored) {
+            // Best-effort close: there is nothing useful to do if closing fails.
+        }
+    }
+
private static int lookupUidByFamily(int protocol, InetSocketAddress local,
InetSocketAddress remote, int family, short flags,
FileDescriptor fd)
@@ -247,13 +271,7 @@
| InterruptedIOException e) {
Log.e(TAG, e.toString());
} finally {
- if (fd != null) {
- try {
- SocketUtils.closeSocket(fd);
- } catch (IOException e) {
- Log.e(TAG, e.toString());
- }
- }
+ closeSocketQuietly(fd);
}
return uid;
}
@@ -269,7 +287,185 @@
(short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_DUMP) /* flag */,
0 /* pad */,
1 << NetlinkConstants.INET_DIAG_MEMINFO /* idiagExt */,
- NetlinkUtils.TCP_MONITOR_STATE_FILTER);
+ TCP_ALIVE_STATE_FILTER);
+ }
+
+    /**
+     * Sends a SOCK_DESTROY request for the socket described by {@code diagMsg} on {@code fd},
+     * then reads the reply through NetlinkUtils.receiveNetlinkAck().
+     *
+     * <p>The request reuses the socket id and family from the dump message. Its state filter
+     * contains only the state the socket was reported in.
+     */
+    private static void sendNetlinkDestroyRequest(FileDescriptor fd, int proto,
+            InetDiagMessage diagMsg) throws InterruptedIOException, ErrnoException {
+        final StructInetDiagMsg target = diagMsg.inetDiagMsg;
+        final short requestFlags =
+                (short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_ACK);
+        final int stateFilter = 1 << target.idiag_state;
+        final byte[] request = InetDiagMessage.inetDiagReqV2(proto, target.id,
+                target.idiag_family, SOCK_DESTROY, requestFlags, 0 /* pad */,
+                0 /* idiagExt */, stateFilter);
+        NetlinkUtils.sendMessage(fd, request, 0, request.length, IO_TIMEOUT_MS);
+        NetlinkUtils.receiveNetlinkAck(fd);
+    }
+
+    /**
+     * Sends, on {@code fd}, a SOCK_DIAG_BY_FAMILY dump request without a socket id for
+     * {@code proto} sockets of {@code family} whose state is in the {@code states} bitmask.
+     *
+     * <p>Only the request is sent. The caller reads the dump responses from {@code fd}.
+     */
+    private static void sendNetlinkDumpRequest(FileDescriptor fd, int proto, int states,
+            int family) throws InterruptedIOException, ErrnoException {
+        final short dumpFlags =
+                (short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_DUMP);
+        final byte[] request = InetDiagMessage.inetDiagReqV2(proto, null /* id */, family,
+                SOCK_DIAG_BY_FAMILY, dumpFlags, 0 /* pad */, 0 /* idiagExt */, states);
+        NetlinkUtils.sendMessage(fd, request, 0, request.length, IO_TIMEOUT_MS);
+    }
+
+    /**
+     * Reads dump responses from {@code dumpFd} until NLMSG_DONE. For every dumped socket
+     * accepted by {@code filter}, sends a destroy request on {@code destroyFd}.
+     *
+     * <p>A failed destroy is logged (ENOENT is ignored silently) and does not stop the dump.
+     * Exceptions thrown while receiving from {@code dumpFd} propagate to the caller.
+     *
+     * @return the number of sockets whose destroy request completed without an exception
+     */
+    private static int processNetlinkDumpAndDestroySockets(FileDescriptor dumpFd,
+            FileDescriptor destroyFd, int proto, Predicate<InetDiagMessage> filter)
+            throws InterruptedIOException, ErrnoException {
+        int destroyed = 0;
+        for (;;) {
+            final ByteBuffer recvBuf = NetlinkUtils.recvMessage(
+                    dumpFd, DEFAULT_RECV_BUFSIZE, IO_TIMEOUT_MS);
+            while (recvBuf.remaining() > 0) {
+                final int start = recvBuf.position();
+                final NetlinkMessage msg = NetlinkMessage.parse(recvBuf, NETLINK_INET_DIAG);
+                if (msg == null) {
+                    // Rewind so the logged hex dump starts at the unparseable message. The rest
+                    // of this buffer is dropped and the next buffer is read.
+                    recvBuf.position(start);
+                    Log.e(TAG, "Failed to parse netlink message: " + hexify(recvBuf));
+                    break;
+                }
+                if (msg.getHeader().nlmsg_type == NLMSG_DONE) {
+                    return destroyed;
+                }
+                if (!(msg instanceof InetDiagMessage)) {
+                    Log.wtf(TAG, "Received unexpected netlink message: " + msg);
+                    continue;
+                }
+                final InetDiagMessage diagMsg = (InetDiagMessage) msg;
+                if (!filter.test(diagMsg)) {
+                    continue;
+                }
+                try {
+                    sendNetlinkDestroyRequest(destroyFd, proto, diagMsg);
+                    destroyed++;
+                } catch (ErrnoException e) {
+                    // ENOENT presumably means the socket disappeared after it was dumped.
+                    if (e.errno != ENOENT) {
+                        Log.e(TAG, "Failed to destroy socket: diagMsg=" + diagMsg + ", " + e);
+                    }
+                } catch (InterruptedIOException e) {
+                    Log.e(TAG, "Failed to destroy socket: diagMsg=" + diagMsg + ", " + e);
+                }
+            }
+        }
+    }
+
+    /**
+     * Returns whether {@code msg} looks like an adb socket, i.e. one owned by
+     * {@link Process#SHELL_UID}.
+     */
+    @VisibleForTesting
+    public static boolean isAdbSocket(final InetDiagMessage msg) {
+        // Heuristic only: adb could run as ROOT_UID, and other services can run as SHELL_UID,
+        // but checking the owner uid covers most cases well enough. Reading the
+        // service.adb.tcp.port system property is prohibited by sepolicy.
+        // TODO: skip the socket only if there is a listen socket owned by SHELL_UID with the same
+        // source port as this socket
+        final int ownerUid = msg.inetDiagMsg.idiag_uid;
+        return ownerUid == Process.SHELL_UID;
+    }
+
+    /**
+     * Returns whether the owner uid of {@code msg} lies inside at least one of {@code ranges}.
+     * Returns false when {@code ranges} is empty.
+     */
+    @VisibleForTesting
+    public static boolean containsUid(InetDiagMessage msg, Set<Range<Integer>> ranges) {
+        return ranges.stream().anyMatch(range -> range.contains(msg.inetDiagMsg.idiag_uid));
+    }
+
+    /**
+     * Returns whether {@code addr} is a loopback address. IPv4-mapped IPv6 addresses in
+     * 127.0.0.0/8 (::ffff:127.x.y.z) also count as loopback.
+     */
+    private static boolean isLoopbackAddress(InetAddress addr) {
+        if (addr.isLoopbackAddress()) {
+            return true;
+        }
+        if (!(addr instanceof Inet6Address)) {
+            return false;
+        }
+        // StructInetDiagSockId stores IPv4 endpoints as v4-mapped Inet6Address values (see
+        // StructInetDiagSockId#parse). Mapped layout: ten zero bytes, 0xff 0xff, IPv4 bytes.
+        final byte[] bytes = addr.getAddress();
+        int zeroPrefixLen = 0;
+        while (zeroPrefixLen < 10 && bytes[zeroPrefixLen] == 0) {
+            zeroPrefixLen++;
+        }
+        final boolean isV4Mapped = zeroPrefixLen == 10
+                && bytes[10] == (byte) 0xff
+                && bytes[11] == (byte) 0xff;
+        return isV4Mapped && bytes[12] == 127;
+    }
+
+    /**
+     * Returns whether the socket in {@code msg} is loopback. A socket counts as loopback when
+     * either endpoint has a loopback address or both endpoints share the same address.
+     */
+    @VisibleForTesting
+    public static boolean isLoopback(InetDiagMessage msg) {
+        final StructInetDiagSockId sockId = msg.inetDiagMsg.id;
+        final InetAddress local = sockId.locSocketAddress.getAddress();
+        final InetAddress remote = sockId.remSocketAddress.getAddress();
+        if (isLoopbackAddress(local) || isLoopbackAddress(remote)) {
+            return true;
+        }
+        return local.equals(remote);
+    }
+
+    /**
+     * Dumps the {@code proto} sockets whose state is in {@code states}, first for AF_INET and
+     * then for AF_INET6, and destroys each socket accepted by {@code filter}.
+     *
+     * <p>The dump and the destroy requests use two separate netlink sockets, and both are
+     * closed before returning. If sending the dump request for a family fails, the failure
+     * is logged and that family is skipped.
+     */
+    private static void destroySockets(int proto, int states, Predicate<InetDiagMessage> filter)
+            throws ErrnoException, SocketException, InterruptedIOException {
+        FileDescriptor dumpFd = null;
+        FileDescriptor destroyFd = null;
+        try {
+            dumpFd = NetlinkUtils.createNetLinkInetDiagSocket();
+            destroyFd = NetlinkUtils.createNetLinkInetDiagSocket();
+            connectSocketToNetlink(dumpFd);
+            connectSocketToNetlink(destroyFd);
+
+            final int[] families = {AF_INET, AF_INET6};
+            for (final int family : families) {
+                try {
+                    sendNetlinkDumpRequest(dumpFd, proto, states, family);
+                } catch (InterruptedIOException | ErrnoException e) {
+                    Log.e(TAG, "Failed to send netlink dump request: " + e);
+                    continue;
+                }
+                final int destroyed =
+                        processNetlinkDumpAndDestroySockets(dumpFd, destroyFd, proto, filter);
+                Log.d(TAG, "Destroyed " + destroyed + " sockets"
+                        + ", proto=" + stringForProtocol(proto)
+                        + ", family=" + stringForAddressFamily(family)
+                        + ", states=" + states);
+            }
+        } finally {
+            closeSocketQuietly(dumpFd);
+            closeSocketQuietly(destroyFd);
+        }
+    }
+
+    /**
+     * Closes every TCP socket that meets all of the following conditions:
+     * <ol>
+     *   <li>its TCP state is TCP_ESTABLISHED, TCP_SYN_SENT, or TCP_SYN_RECV;</li>
+     *   <li>its owner uid is not in {@code exemptUids};</li>
+     *   <li>its owner uid is inside one of {@code ranges};</li>
+     *   <li>it is not a loopback socket;</li>
+     *   <li>it is not an adb socket.</li>
+     * </ol>
+     *
+     * @param ranges target uid ranges
+     * @param exemptUids uids whose sockets are never closed
+     */
+    public static void destroyLiveTcpSockets(Set<Range<Integer>> ranges, Set<Integer> exemptUids)
+            throws SocketException, InterruptedIOException, ErrnoException {
+        final Predicate<InetDiagMessage> shouldDestroy = diagMsg ->
+                !exemptUids.contains(diagMsg.inetDiagMsg.idiag_uid)
+                        && containsUid(diagMsg, ranges)
+                        && !isLoopback(diagMsg)
+                        && !isAdbSocket(diagMsg);
+        destroySockets(IPPROTO_TCP, TCP_ALIVE_STATE_FILTER, shouldDestroy);
}
@Override
diff --git a/common/device/com/android/net/module/util/netlink/NetlinkUtils.java b/common/device/com/android/net/module/util/netlink/NetlinkUtils.java
index d4bf14a..308ea24 100644
--- a/common/device/com/android/net/module/util/netlink/NetlinkUtils.java
+++ b/common/device/com/android/net/module/util/netlink/NetlinkUtils.java
@@ -56,7 +56,7 @@
private static final int TCP_SYN_SENT = 2;
private static final int TCP_SYN_RECV = 3;
- public static final int TCP_MONITOR_STATE_FILTER =
+ /** Bitmask selecting the TCP_ESTABLISHED, TCP_SYN_SENT and TCP_SYN_RECV states. */
+ public static final int TCP_ALIVE_STATE_FILTER =
(1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
public static final int UNKNOWN_MARK = 0xffffffff;
diff --git a/common/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java b/common/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
index 30796d2..65e99f8 100644
--- a/common/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
+++ b/common/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
@@ -16,6 +16,8 @@
package com.android.net.module.util.netlink;
+import static android.os.Process.ROOT_UID;
+import static android.os.Process.SHELL_UID;
import static android.system.OsConstants.AF_INET;
import static android.system.OsConstants.AF_INET6;
import static android.system.OsConstants.IPPROTO_TCP;
@@ -34,6 +36,8 @@
import static org.junit.Assert.fail;
import android.net.InetAddresses;
+import android.util.ArraySet;
+import android.util.Range;
import androidx.test.filters.SmallTest;
import androidx.test.runner.AndroidJUnit4;
@@ -46,8 +50,11 @@
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.List;
+import java.util.Set;
@RunWith(AndroidJUnit4.class)
@SmallTest
@@ -543,4 +550,143 @@
7 /* ifIndex */,
88 /* cookie */);
}
+
+    private void doTestIsLoopback(InetAddress srcAddr, InetAddress dstAddr, boolean expected) {
+        // isLoopback() only looks at the endpoint addresses, so the ports are arbitrary.
+        final StructInetDiagSockId sockId = new StructInetDiagSockId(
+                new InetSocketAddress(srcAddr, 43031), new InetSocketAddress(dstAddr, 38415));
+        final InetDiagMessage msg = new InetDiagMessage(new StructNlMsgHdr());
+        msg.inetDiagMsg.id = sockId;
+        assertEquals(expected, InetDiagMessage.isLoopback(msg));
+    }
+
+    @Test
+    public void testIsLoopback() {
+        // A socket is loopback when either endpoint is a loopback address.
+        final InetAddress v4Remote = InetAddresses.parseNumericAddress("192.0.2.1");
+        final InetAddress v6Loopback = InetAddresses.parseNumericAddress("::1");
+        doTestIsLoopback(InetAddresses.parseNumericAddress("127.0.0.1"), v4Remote, true);
+        doTestIsLoopback(v4Remote, InetAddresses.parseNumericAddress("127.7.7.7"), true);
+        doTestIsLoopback(v6Loopback, v6Loopback, true);
+        doTestIsLoopback(v6Loopback, InetAddresses.parseNumericAddress("2001:db8::1"), true);
+    }
+
+    @Test
+    public void testIsLoopbackSameSrcDstAddress() {
+        // Non-loopback addresses still make a loopback socket when both endpoints are equal.
+        for (final String addr : new String[] {"192.0.2.1", "2001:db8::1"}) {
+            doTestIsLoopback(InetAddresses.parseNumericAddress(addr),
+                    InetAddresses.parseNumericAddress(addr), true);
+        }
+    }
+
+    @Test
+    public void testIsLoopbackNonLoopbackSocket() {
+        // Distinct, non-loopback endpoints: { src, dst } pairs.
+        final String[][] endpointPairs = {
+                {"192.0.2.1", "192.0.2.2"},
+                {"2001:db8::1", "2001:db8::2"},
+        };
+        for (final String[] pair : endpointPairs) {
+            doTestIsLoopback(InetAddresses.parseNumericAddress(pair[0]),
+                    InetAddresses.parseNumericAddress(pair[1]), false);
+        }
+    }
+
+    /** Builds the IPv4-mapped IPv6 address ::ffff:a.b.c.d as an Inet6Address. */
+    private static Inet6Address makeV4MappedV6Address(int a, int b, int c, int d)
+            throws UnknownHostException {
+        final byte[] bytes = new byte[16];  // bytes 0..9 stay zero
+        bytes[10] = (byte) 0xff;
+        bytes[11] = (byte) 0xff;
+        bytes[12] = (byte) a;
+        bytes[13] = (byte) b;
+        bytes[14] = (byte) c;
+        bytes[15] = (byte) d;
+        return Inet6Address.getByAddress(null, bytes, -1);
+    }
+
+    @Test
+    public void testIsLoopbackV4MappedV6() throws UnknownHostException {
+        final Inet6Address addrLoopback = makeV4MappedV6Address(127, 1, 2, 3);
+        final Inet6Address addrNonLoopback1 = makeV4MappedV6Address(192, 0, 2, 1);
+        final Inet6Address addrNonLoopback2 = makeV4MappedV6Address(192, 0, 2, 2);
+
+        doTestIsLoopback(addrLoopback, addrNonLoopback1, true);
+        doTestIsLoopback(addrNonLoopback1, addrNonLoopback2, false);
+        doTestIsLoopback(addrNonLoopback1, addrNonLoopback1, true);
+    }
+
+    private void doTestContainsUid(final int uid, final Set<Range<Integer>> ranges,
+            final boolean expected) {
+        final InetDiagMessage msg = new InetDiagMessage(new StructNlMsgHdr());
+        msg.inetDiagMsg.idiag_uid = uid;
+        final boolean actual = InetDiagMessage.containsUid(msg, ranges);
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void testContainsUid() {
+        final int uid = 77;
+
+        // uid inside a range, including a single-value range.
+        doTestContainsUid(uid, new ArraySet<>(List.of(new Range<>(0, 100))), true);
+        doTestContainsUid(uid,
+                new ArraySet<>(List.of(new Range<>(77, 77), new Range<>(100, 200))), true);
+
+        // uid outside every range, including ranges that end right next to it.
+        doTestContainsUid(uid, new ArraySet<>(List.of(new Range<>(100, 200))), false);
+        doTestContainsUid(uid,
+                new ArraySet<>(List.of(new Range<>(0, 76), new Range<>(78, 100))), false);
+    }
+
+    private void doTestIsAdbSocket(final int uid, final boolean expected) {
+        final InetSocketAddress local =
+                new InetSocketAddress(InetAddresses.parseNumericAddress("2001:db8::1"), 38417);
+        final InetSocketAddress remote =
+                new InetSocketAddress(InetAddresses.parseNumericAddress("2001:db8::2"), 38415);
+        final InetDiagMessage msg = new InetDiagMessage(new StructNlMsgHdr());
+        msg.inetDiagMsg.idiag_uid = uid;
+        msg.inetDiagMsg.id = new StructInetDiagSockId(local, remote);
+        assertEquals(expected, InetDiagMessage.isAdbSocket(msg));
+    }
+
+    @Test
+    public void testIsAdbSocket() {
+        // Only sockets owned by SHELL_UID are treated as adb sockets.
+        doTestIsAdbSocket(SHELL_UID, true /* expected */);
+        final int[] nonShellUids = {ROOT_UID, 10108 /* app uid */};
+        for (final int uid : nonShellUids) {
+            doTestIsAdbSocket(uid, false /* expected */);
+        }
+    }
}