Add InetDiagMessage.destroyLiveTcpSockets am: a22e2ed0fa am: d854f708bf

Original change: https://googleplex-android-review.googlesource.com/c/platform/frameworks/libs/net/+/23126578

Change-Id: I7dbe6f973de7d29677841f3ef261c70e6abd92a1
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/common/device/com/android/net/module/util/netlink/InetDiagMessage.java b/common/device/com/android/net/module/util/netlink/InetDiagMessage.java
index 0a2f50d..d462c53 100644
--- a/common/device/com/android/net/module/util/netlink/InetDiagMessage.java
+++ b/common/device/com/android/net/module/util/netlink/InetDiagMessage.java
@@ -19,32 +19,48 @@
 import static android.os.Process.INVALID_UID;
 import static android.system.OsConstants.AF_INET;
 import static android.system.OsConstants.AF_INET6;
+import static android.system.OsConstants.ENOENT;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
 import static android.system.OsConstants.NETLINK_INET_DIAG;
 
+import static com.android.net.module.util.netlink.NetlinkConstants.NLMSG_DONE;
+import static com.android.net.module.util.netlink.NetlinkConstants.SOCK_DESTROY;
 import static com.android.net.module.util.netlink.NetlinkConstants.SOCK_DIAG_BY_FAMILY;
+import static com.android.net.module.util.netlink.NetlinkConstants.hexify;
+import static com.android.net.module.util.netlink.NetlinkConstants.stringForAddressFamily;
+import static com.android.net.module.util.netlink.NetlinkConstants.stringForProtocol;
 import static com.android.net.module.util.netlink.NetlinkUtils.DEFAULT_RECV_BUFSIZE;
+import static com.android.net.module.util.netlink.NetlinkUtils.IO_TIMEOUT_MS;
+import static com.android.net.module.util.netlink.NetlinkUtils.TCP_ALIVE_STATE_FILTER;
+import static com.android.net.module.util.netlink.NetlinkUtils.connectSocketToNetlink;
 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_DUMP;
 import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST;
 
 import android.net.util.SocketUtils;
+import android.os.Process;
 import android.system.ErrnoException;
 import android.util.Log;
+import android.util.Range;
 
 import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
+import androidx.annotation.VisibleForTesting;
 
 import java.io.FileDescriptor;
 import java.io.IOException;
 import java.io.InterruptedIOException;
 import java.net.Inet4Address;
 import java.net.Inet6Address;
+import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.SocketException;
 import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Predicate;
 
 /**
  * A NetlinkMessage subclass for netlink inet_diag messages.
@@ -138,7 +154,8 @@
 
     public StructInetDiagMsg inetDiagMsg;
 
-    private InetDiagMessage(@NonNull StructNlMsgHdr header) {
+    @VisibleForTesting
+    public InetDiagMessage(@NonNull StructNlMsgHdr header) {
         super(header);
         inetDiagMsg = new StructInetDiagMsg();
     }
@@ -157,6 +174,13 @@
         return msg;
     }
 
+    private static void closeSocketQuietly(final FileDescriptor fd) {
+        try {
+            SocketUtils.closeSocket(fd);
+        } catch (IOException ignored) {
+        }
+    }
+
     private static int lookupUidByFamily(int protocol, InetSocketAddress local,
                                          InetSocketAddress remote, int family, short flags,
                                          FileDescriptor fd)
@@ -247,13 +271,7 @@
                 | InterruptedIOException e) {
             Log.e(TAG, e.toString());
         } finally {
-            if (fd != null) {
-                try {
-                    SocketUtils.closeSocket(fd);
-                } catch (IOException e) {
-                    Log.e(TAG, e.toString());
-                }
-            }
+            closeSocketQuietly(fd);
         }
         return uid;
     }
@@ -269,7 +287,185 @@
                 (short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_DUMP) /* flag */,
                 0 /* pad */,
                 1 << NetlinkConstants.INET_DIAG_MEMINFO /* idiagExt */,
-                NetlinkUtils.TCP_MONITOR_STATE_FILTER);
+                TCP_ALIVE_STATE_FILTER);
+    }
+
+    private static void sendNetlinkDestroyRequest(FileDescriptor fd, int proto,
+            InetDiagMessage diagMsg) throws InterruptedIOException, ErrnoException {
+        final byte[] destroyMsg = InetDiagMessage.inetDiagReqV2(
+                proto,
+                diagMsg.inetDiagMsg.id,
+                diagMsg.inetDiagMsg.idiag_family,
+                SOCK_DESTROY,
+                (short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_ACK),
+                0 /* pad */,
+                0 /* idiagExt */,
+                1 << diagMsg.inetDiagMsg.idiag_state
+        );
+        NetlinkUtils.sendMessage(fd, destroyMsg, 0, destroyMsg.length, IO_TIMEOUT_MS);
+        NetlinkUtils.receiveNetlinkAck(fd);
+    }
+
+    private static void sendNetlinkDumpRequest(FileDescriptor fd, int proto, int states, int family)
+            throws InterruptedIOException, ErrnoException {
+        final byte[] dumpMsg = InetDiagMessage.inetDiagReqV2(
+                proto,
+                null /* id */,
+                family,
+                SOCK_DIAG_BY_FAMILY,
+                (short) (StructNlMsgHdr.NLM_F_REQUEST | StructNlMsgHdr.NLM_F_DUMP),
+                0 /* pad */,
+                0 /* idiagExt */,
+                states);
+        NetlinkUtils.sendMessage(fd, dumpMsg, 0, dumpMsg.length, IO_TIMEOUT_MS);
+    }
+
+    private static int processNetlinkDumpAndDestroySockets(FileDescriptor dumpFd,
+            FileDescriptor destroyFd, int proto, Predicate<InetDiagMessage> filter)
+            throws InterruptedIOException, ErrnoException {
+        int destroyedSockets = 0;
+
+        while (true) {
+            final ByteBuffer buf = NetlinkUtils.recvMessage(
+                    dumpFd, DEFAULT_RECV_BUFSIZE, IO_TIMEOUT_MS);
+
+            while (buf.remaining() > 0) {
+                final int position = buf.position();
+                final NetlinkMessage nlMsg = NetlinkMessage.parse(buf, NETLINK_INET_DIAG);
+                if (nlMsg == null) {
+                    // Move to the position where parse started for error log.
+                    buf.position(position);
+                    Log.e(TAG, "Failed to parse netlink message: " + hexify(buf));
+                    break;
+                }
+
+                if (nlMsg.getHeader().nlmsg_type == NLMSG_DONE) {
+                    return destroyedSockets;
+                }
+
+                if (!(nlMsg instanceof InetDiagMessage)) {
+                    Log.wtf(TAG, "Received unexpected netlink message: " + nlMsg);
+                    continue;
+                }
+
+                final InetDiagMessage diagMsg = (InetDiagMessage) nlMsg;
+                if (filter.test(diagMsg)) {
+                    try {
+                        sendNetlinkDestroyRequest(destroyFd, proto, diagMsg);
+                        destroyedSockets++;
+                    } catch (InterruptedIOException | ErrnoException e) {
+                        if (!(e instanceof ErrnoException
+                                && ((ErrnoException) e).errno == ENOENT)) {
+                            Log.e(TAG, "Failed to destroy socket: diagMsg=" + diagMsg + ", " + e);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Returns whether the InetDiagMessage is for an adb socket or not.
+     */
+    @VisibleForTesting
+    public static boolean isAdbSocket(final InetDiagMessage msg) {
+        // This is inaccurate since adb could run with ROOT_UID or other services can run with
+        // SHELL_UID. But this check covers most cases and is good enough.
+        // Note that getting service.adb.tcp.port system property is prohibited by sepolicy
+        // TODO: skip the socket only if there is a listen socket owned by SHELL_UID with the same
+        // source port as this socket
+        return msg.inetDiagMsg.idiag_uid == Process.SHELL_UID;
+    }
+
+    /**
+     * Returns whether any of the given ranges contains the uid in the InetDiagMessage.
+     */
+    @VisibleForTesting
+    public static boolean containsUid(InetDiagMessage msg, Set<Range<Integer>> ranges) {
+        for (final Range<Integer> range: ranges) {
+            if (range.contains(msg.inetDiagMsg.idiag_uid)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    private static boolean isLoopbackAddress(InetAddress addr) {
+        if (addr.isLoopbackAddress()) return true;
+        if (!(addr instanceof Inet6Address)) return false;
+
+        // The following check is for v4-mapped v6 addresses. StructInetDiagSockId holds a
+        // v4-mapped v6 address as an Inet6Address; see StructInetDiagSockId#parse.
+        final byte[] addrBytes = addr.getAddress();
+        for (int i = 0; i < 10; i++) {
+            if (addrBytes[i] != 0) return false;
+        }
+        return addrBytes[10] == (byte) 0xff
+                && addrBytes[11] == (byte) 0xff
+                && addrBytes[12] == 127;
+    }
+
+    /**
+     * Returns whether either address is loopback or the local and remote addresses are equal.
+     */
+    @VisibleForTesting
+    public static boolean isLoopback(InetDiagMessage msg) {
+        final InetAddress srcAddr = msg.inetDiagMsg.id.locSocketAddress.getAddress();
+        final InetAddress dstAddr = msg.inetDiagMsg.id.remSocketAddress.getAddress();
+        return isLoopbackAddress(srcAddr)
+                || isLoopbackAddress(dstAddr)
+                || srcAddr.equals(dstAddr);
+    }
+
+    private static void destroySockets(int proto, int states, Predicate<InetDiagMessage> filter)
+            throws ErrnoException, SocketException, InterruptedIOException {
+        FileDescriptor dumpFd = null;
+        FileDescriptor destroyFd = null;
+
+        try {
+            dumpFd = NetlinkUtils.createNetLinkInetDiagSocket();
+            destroyFd = NetlinkUtils.createNetLinkInetDiagSocket();
+            connectSocketToNetlink(dumpFd);
+            connectSocketToNetlink(destroyFd);
+
+            for (int family : List.of(AF_INET, AF_INET6)) {
+                try {
+                    sendNetlinkDumpRequest(dumpFd, proto, states, family);
+                } catch (InterruptedIOException | ErrnoException e) {
+                    Log.e(TAG, "Failed to send netlink dump request: " + e);
+                    continue;
+                }
+                final int destroyedSockets = processNetlinkDumpAndDestroySockets(
+                        dumpFd, destroyFd, proto, filter);
+                Log.d(TAG, "Destroyed " + destroyedSockets + " sockets"
+                        + ", proto=" + stringForProtocol(proto)
+                        + ", family=" + stringForAddressFamily(family)
+                        + ", states=" + states);
+            }
+        } finally {
+            closeSocketQuietly(dumpFd);
+            closeSocketQuietly(destroyFd);
+        }
+    }
+
+    /**
+     * Destroy TCP sockets that match all of the following conditions:
+     *  1. TCP state is one of TCP_ESTABLISHED, TCP_SYN_SENT, and TCP_SYN_RECV
+     *  2. Owner uid of the socket is not in exemptUids
+     *  3. Owner uid of the socket is in one of the ranges
+     *  4. Socket is not loopback
+     *  5. Socket is not an adb socket
+     *
+     * @param ranges target uid ranges
+     * @param exemptUids uids whose sockets must not be closed
+     */
+    public static void destroyLiveTcpSockets(Set<Range<Integer>> ranges, Set<Integer> exemptUids)
+            throws SocketException, InterruptedIOException, ErrnoException {
+        destroySockets(IPPROTO_TCP, TCP_ALIVE_STATE_FILTER,
+                (diagMsg) -> !exemptUids.contains(diagMsg.inetDiagMsg.idiag_uid)
+                        && containsUid(diagMsg, ranges)
+                        && !isLoopback(diagMsg)
+                        && !isAdbSocket(diagMsg));
     }
 
     @Override
diff --git a/common/device/com/android/net/module/util/netlink/NetlinkUtils.java b/common/device/com/android/net/module/util/netlink/NetlinkUtils.java
index d4bf14a..308ea24 100644
--- a/common/device/com/android/net/module/util/netlink/NetlinkUtils.java
+++ b/common/device/com/android/net/module/util/netlink/NetlinkUtils.java
@@ -56,7 +56,7 @@
     private static final int TCP_SYN_SENT = 2;
     private static final int TCP_SYN_RECV = 3;
 
-    public static final int TCP_MONITOR_STATE_FILTER =
+    public static final int TCP_ALIVE_STATE_FILTER =
             (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
 
     public static final int UNKNOWN_MARK = 0xffffffff;
diff --git a/common/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java b/common/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
index 30796d2..65e99f8 100644
--- a/common/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
+++ b/common/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
@@ -16,6 +16,8 @@
 
 package com.android.net.module.util.netlink;
 
+import static android.os.Process.ROOT_UID;
+import static android.os.Process.SHELL_UID;
 import static android.system.OsConstants.AF_INET;
 import static android.system.OsConstants.AF_INET6;
 import static android.system.OsConstants.IPPROTO_TCP;
@@ -34,6 +36,8 @@
 import static org.junit.Assert.fail;
 
 import android.net.InetAddresses;
+import android.util.ArraySet;
+import android.util.Range;
 
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
@@ -46,8 +50,11 @@
 import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.util.List;
+import java.util.Set;
 
 @RunWith(AndroidJUnit4.class)
 @SmallTest
@@ -543,4 +550,143 @@
                 7  /* ifIndex */,
                 88 /* cookie */);
     }
+
+    private void doTestIsLoopback(InetAddress srcAddr, InetAddress dstAddr, boolean expected) {
+        final InetDiagMessage inetDiagMsg = new InetDiagMessage(new StructNlMsgHdr());
+        inetDiagMsg.inetDiagMsg.id = new StructInetDiagSockId(
+                new InetSocketAddress(srcAddr, 43031),
+                new InetSocketAddress(dstAddr, 38415)
+        );
+
+        assertEquals(expected, InetDiagMessage.isLoopback(inetDiagMsg));
+    }
+
+    @Test
+    public void testIsLoopback() {
+        doTestIsLoopback(
+                InetAddresses.parseNumericAddress("127.0.0.1"),
+                InetAddresses.parseNumericAddress("192.0.2.1"),
+                true
+        );
+        doTestIsLoopback(
+                InetAddresses.parseNumericAddress("192.0.2.1"),
+                InetAddresses.parseNumericAddress("127.7.7.7"),
+                true
+        );
+        doTestIsLoopback(
+                InetAddresses.parseNumericAddress("::1"),
+                InetAddresses.parseNumericAddress("::1"),
+                true
+        );
+        doTestIsLoopback(
+                InetAddresses.parseNumericAddress("::1"),
+                InetAddresses.parseNumericAddress("2001:db8::1"),
+                true
+        );
+    }
+
+    @Test
+    public void testIsLoopbackSameSrcDstAddress()  {
+        doTestIsLoopback(
+                InetAddresses.parseNumericAddress("192.0.2.1"),
+                InetAddresses.parseNumericAddress("192.0.2.1"),
+                true
+        );
+        doTestIsLoopback(
+                InetAddresses.parseNumericAddress("2001:db8::1"),
+                InetAddresses.parseNumericAddress("2001:db8::1"),
+                true
+        );
+    }
+
+    @Test
+    public void testIsLoopbackNonLoopbackSocket()  {
+        doTestIsLoopback(
+                InetAddresses.parseNumericAddress("192.0.2.1"),
+                InetAddresses.parseNumericAddress("192.0.2.2"),
+                false
+        );
+        doTestIsLoopback(
+                InetAddresses.parseNumericAddress("2001:db8::1"),
+                InetAddresses.parseNumericAddress("2001:db8::2"),
+                false
+        );
+    }
+
+    @Test
+    public void testIsLoopbackV4MappedV6() throws UnknownHostException {
+        // ::FFFF:127.1.2.3
+        final byte[] addrLoopbackByte = {
+                (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                (byte) 0x00, (byte) 0x00, (byte) 0xff, (byte) 0xff,
+                (byte) 0x7f, (byte) 0x01, (byte) 0x02, (byte) 0x03,
+        };
+        // ::FFFF:192.0.2.1
+        final byte[] addrNonLoopbackByte1 = {
+                (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                (byte) 0x00, (byte) 0x00, (byte) 0xff, (byte) 0xff,
+                (byte) 0xc0, (byte) 0x00, (byte) 0x02, (byte) 0x01,
+        };
+        // ::FFFF:192.0.2.2
+        final byte[] addrNonLoopbackByte2 = {
+                (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                (byte) 0x00, (byte) 0x00, (byte) 0xff, (byte) 0xff,
+                (byte) 0xc0, (byte) 0x00, (byte) 0x02, (byte) 0x02,
+        };
+
+        final Inet6Address addrLoopback = Inet6Address.getByAddress(null, addrLoopbackByte, -1);
+        final Inet6Address addrNonLoopback1 =
+                Inet6Address.getByAddress(null, addrNonLoopbackByte1, -1);
+        final Inet6Address addrNonLoopback2 =
+                Inet6Address.getByAddress(null, addrNonLoopbackByte2, -1);
+
+        doTestIsLoopback(addrLoopback, addrNonLoopback1, true);
+        doTestIsLoopback(addrNonLoopback1, addrNonLoopback2, false);
+        doTestIsLoopback(addrNonLoopback1, addrNonLoopback1, true);
+    }
+
+    private void doTestContainsUid(final int uid, final Set<Range<Integer>> ranges,
+            final boolean expected) {
+        final InetDiagMessage inetDiagMsg = new InetDiagMessage(new StructNlMsgHdr());
+        inetDiagMsg.inetDiagMsg.idiag_uid = uid;
+        assertEquals(expected, InetDiagMessage.containsUid(inetDiagMsg, ranges));
+    }
+
+    @Test
+    public void testContainsUid() {
+        doTestContainsUid(77 /* uid */,
+                new ArraySet<>(List.of(new Range<>(0, 100))),
+                true /* expected */);
+        doTestContainsUid(77 /* uid */,
+                new ArraySet<>(List.of(new Range<>(77, 77), new Range<>(100, 200))),
+                true /* expected */);
+
+        doTestContainsUid(77 /* uid */,
+                new ArraySet<>(List.of(new Range<>(100, 200))),
+                false /* expected */);
+        doTestContainsUid(77 /* uid */,
+                new ArraySet<>(List.of(new Range<>(0, 76), new Range<>(78, 100))),
+                false /* expected */);
+    }
+
+    private void doTestIsAdbSocket(final int uid, final boolean expected) {
+        final InetDiagMessage inetDiagMsg = new InetDiagMessage(new StructNlMsgHdr());
+        inetDiagMsg.inetDiagMsg.idiag_uid = uid;
+        inetDiagMsg.inetDiagMsg.id = new StructInetDiagSockId(
+                new InetSocketAddress(InetAddresses.parseNumericAddress("2001:db8::1"), 38417),
+                new InetSocketAddress(InetAddresses.parseNumericAddress("2001:db8::2"), 38415)
+        );
+        assertEquals(expected, InetDiagMessage.isAdbSocket(inetDiagMsg));
+    }
+
+    @Test
+    public void testIsAdbSocket() {
+        final int appUid = 10108;
+        doTestIsAdbSocket(SHELL_UID,  true /* expected */);
+        doTestIsAdbSocket(ROOT_UID, false /* expected */);
+        doTestIsAdbSocket(appUid, false /* expected */);
+    }
 }