Support configuring the VCN restriction policy

This commit implements the logic to support configuring
VCN restriction policy on the underlying network. This commit
also adds a stub for later CLs to add the configuration interface.

Bug: 239104955
Test: atest FrameworksVcnTests(new tests)
Test: atest CtsVcnTestCases
Change-Id: I1687df029b5bbd8f20e0e4839ba8bd1717051dcf
diff --git a/services/core/java/com/android/server/VcnManagementService.java b/services/core/java/com/android/server/VcnManagementService.java
index 76cac93..6f49db1 100644
--- a/services/core/java/com/android/server/VcnManagementService.java
+++ b/services/core/java/com/android/server/VcnManagementService.java
@@ -18,8 +18,10 @@
 
 import static android.Manifest.permission.DUMP;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED;
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_TEST;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+import static android.net.vcn.VcnGatewayConnectionConfig.ALLOWED_CAPABILITIES;
 import static android.net.vcn.VcnManager.VCN_STATUS_CODE_ACTIVE;
 import static android.net.vcn.VcnManager.VCN_STATUS_CODE_INACTIVE;
 import static android.net.vcn.VcnManager.VCN_STATUS_CODE_NOT_CONFIGURED;
@@ -68,6 +70,7 @@
 import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.util.ArrayMap;
+import android.util.ArraySet;
 import android.util.LocalLog;
 import android.util.Log;
 import android.util.Slog;
@@ -361,6 +364,16 @@
         public LocationPermissionChecker newLocationPermissionChecker(@NonNull Context context) {
             return new LocationPermissionChecker(context);
         }
+
+        /** Gets the transports that need to be marked as restricted by the VCN */
+        public Set<Integer> getRestrictedTransports(
+                ParcelUuid subGrp,
+                Map<ParcelUuid, VcnConfig> vcnConfigs,
+                TelephonySubscriptionSnapshot lastSnapshot) {
+            // TODO: b/239104955 Read restriction policy configurations
+
+            return Collections.singleton(TRANSPORT_WIFI);
+        }
     }
 
     /** Notifies the VcnManagementService that external dependencies can be set up. */
@@ -1000,7 +1013,7 @@
 
             final ParcelUuid subGrp = getSubGroupForNetworkCapabilities(ncCopy);
             boolean isVcnManagedNetwork = false;
-            boolean isRestrictedCarrierWifi = false;
+            boolean isRestricted = false;
             synchronized (mLock) {
                 final Vcn vcn = mVcns.get(subGrp);
                 if (vcn != null) {
@@ -1008,9 +1021,19 @@
                         isVcnManagedNetwork = true;
                     }
 
-                    if (ncCopy.hasTransport(NetworkCapabilities.TRANSPORT_WIFI)) {
-                        // Carrier WiFi always restricted if VCN exists (even in safe mode).
-                        isRestrictedCarrierWifi = true;
+                    final Set<Integer> restrictedTransports =
+                            mDeps.getRestrictedTransports(subGrp, mConfigs, mLastSnapshot);
+                    for (int restrictedTransport : restrictedTransports) {
+                        if (ncCopy.hasTransport(restrictedTransport)) {
+                            if (restrictedTransport == TRANSPORT_CELLULAR) {
+                                // Only make a cell network as restricted when the VCN is in
+                                // active mode.
+                                isRestricted |= (vcn.getStatus() == VCN_STATUS_CODE_ACTIVE);
+                            } else {
+                                isRestricted = true;
+                                break;
+                            }
+                        }
                     }
                 }
             }
@@ -1024,14 +1047,16 @@
                 ncBuilder.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED);
             }
 
-            if (isRestrictedCarrierWifi) {
+            if (isRestricted) {
                 ncBuilder.removeCapability(
                         NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED);
             }
 
             final NetworkCapabilities result = ncBuilder.build();
             final VcnUnderlyingNetworkPolicy policy = new VcnUnderlyingNetworkPolicy(
-                    mTrackingNetworkCallback.requiresRestartForCarrierWifi(result), result);
+                    mTrackingNetworkCallback
+                            .requiresRestartForImmutableCapabilityChanges(result),
+                    result);
 
             logVdbg("getUnderlyingNetworkPolicy() called for caps: " + networkCapabilities
                         + "; and lp: " + linkProperties + "; result = " + policy);
@@ -1296,15 +1321,38 @@
             }
         }
 
-        private boolean requiresRestartForCarrierWifi(NetworkCapabilities caps) {
-            if (!caps.hasTransport(TRANSPORT_WIFI) || caps.getSubscriptionIds() == null) {
+        private Set<Integer> getNonTestTransportTypes(NetworkCapabilities caps) {
+            final Set<Integer> transportTypes = new ArraySet<>();
+            for (int t : caps.getTransportTypes()) {
+                transportTypes.add(t);
+            }
+            return transportTypes;
+        }
+
+        private boolean hasSameTransportsAndCapabilities(
+                NetworkCapabilities caps, NetworkCapabilities capsOther) {
+            if (!Objects.equals(
+                    getNonTestTransportTypes(caps), getNonTestTransportTypes(capsOther))) {
+                return false;
+            }
+
+            for (int capability : ALLOWED_CAPABILITIES) {
+                if (caps.hasCapability(capability) != capsOther.hasCapability(capability)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        private boolean requiresRestartForImmutableCapabilityChanges(NetworkCapabilities caps) {
+            if (caps.getSubscriptionIds() == null) {
                 return false;
             }
 
             synchronized (mCaps) {
                 for (NetworkCapabilities existing : mCaps.values()) {
-                    if (existing.hasTransport(TRANSPORT_WIFI)
-                            && caps.getSubscriptionIds().equals(existing.getSubscriptionIds())) {
+                    if (caps.getSubscriptionIds().equals(existing.getSubscriptionIds())
+                            && hasSameTransportsAndCapabilities(caps, existing)) {
                         // Restart if any immutable capabilities have changed
                         return existing.hasCapability(NET_CAPABILITY_NOT_RESTRICTED)
                                 != caps.hasCapability(NET_CAPABILITY_NOT_RESTRICTED);
diff --git a/tests/vcn/java/com/android/server/VcnManagementServiceTest.java b/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
index f924b2e..478afe8 100644
--- a/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
+++ b/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
@@ -17,6 +17,7 @@
 package com.android.server;
 
 import static android.net.ConnectivityManager.NetworkCallback;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_IMS;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
@@ -252,6 +253,10 @@
                 .when(mMockContext)
                 .enforceCallingOrSelfPermission(
                         eq(android.Manifest.permission.NETWORK_FACTORY), any());
+
+        doReturn(Collections.singleton(TRANSPORT_WIFI))
+                .when(mMockDeps)
+                .getRestrictedTransports(any(), any(), any());
     }
 
 
@@ -1032,63 +1037,135 @@
                 new LinkProperties());
     }
 
-    @Test
-    public void testGetUnderlyingNetworkPolicyCellular() throws Exception {
+    private void checkGetUnderlyingNetworkPolicy(
+            int transportType,
+            boolean isTransportRestricted,
+            boolean isActive,
+            boolean expectVcnManaged,
+            boolean expectRestricted)
+            throws Exception {
+
+        final Set<Integer> restrictedTransports = new ArraySet();
+        if (isTransportRestricted) {
+            restrictedTransports.add(transportType);
+        }
+        doReturn(restrictedTransports).when(mMockDeps).getRestrictedTransports(any(), any(), any());
+
         final VcnUnderlyingNetworkPolicy policy =
                 startVcnAndGetPolicyForTransport(
-                        TEST_SUBSCRIPTION_ID, TEST_UUID_2, true /* isActive */, TRANSPORT_CELLULAR);
+                        TEST_SUBSCRIPTION_ID, TEST_UUID_2, isActive, transportType);
+
+        assertFalse(policy.isTeardownRequested());
+        verifyMergedNetworkCapabilities(
+                policy.getMergedNetworkCapabilities(),
+                transportType,
+                expectVcnManaged,
+                expectRestricted);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicy_unrestrictCell() throws Exception {
+        checkGetUnderlyingNetworkPolicy(
+                TRANSPORT_CELLULAR,
+                false /* isTransportRestricted */,
+                true /* isActive */,
+                true /* expectVcnManaged */,
+                false /* expectRestricted */);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicy_unrestrictCellSafeMode() throws Exception {
+        checkGetUnderlyingNetworkPolicy(
+                TRANSPORT_CELLULAR,
+                false /* isTransportRestricted */,
+                false /* isActive */,
+                false /* expectVcnManaged */,
+                false /* expectRestricted */);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicy_restrictCell() throws Exception {
+        checkGetUnderlyingNetworkPolicy(
+                TRANSPORT_CELLULAR,
+                true /* isTransportRestricted */,
+                true /* isActive */,
+                true /* expectVcnManaged */,
+                true /* expectRestricted */);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicy_restrictCellSafeMode() throws Exception {
+        checkGetUnderlyingNetworkPolicy(
+                TRANSPORT_CELLULAR,
+                true /* isTransportRestricted */,
+                false /* isActive */,
+                false /* expectVcnManaged */,
+                false /* expectRestricted */);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicy_unrestrictWifi() throws Exception {
+        checkGetUnderlyingNetworkPolicy(
+                TRANSPORT_WIFI,
+                false /* isTransportRestricted */,
+                true /* isActive */,
+                true /* expectVcnManaged */,
+                false /* expectRestricted */);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicy_unrestrictWifiSafeMode() throws Exception {
+        checkGetUnderlyingNetworkPolicy(
+                TRANSPORT_WIFI,
+                false /* isTransportRestricted */,
+                false /* isActive */,
+                false /* expectVcnManaged */,
+                false /* expectRestricted */);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicy_restrictWifi() throws Exception {
+        checkGetUnderlyingNetworkPolicy(
+                TRANSPORT_WIFI,
+                true /* isTransportRestricted */,
+                true /* isActive */,
+                true /* expectVcnManaged */,
+                true /* expectRestricted */);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicy_restrictWifiSafeMode() throws Exception {
+        checkGetUnderlyingNetworkPolicy(
+                TRANSPORT_WIFI,
+                true /* isTransportRestricted */,
+                false /* isActive */,
+                false /* expectVcnManaged */,
+                true /* expectRestricted */);
+    }
+
+    @Test
+    public void testGetUnderlyingNetworkPolicyCell_restrictWifi() throws Exception {
+        doReturn(Collections.singleton(TRANSPORT_WIFI))
+                .when(mMockDeps)
+                .getRestrictedTransports(any(), any(), any());
+
+        setupSubscriptionAndStartVcn(TEST_SUBSCRIPTION_ID, TEST_UUID_2, true /* isVcnActive */);
+
+        // Get the policy for a cellular network and expect it won't be affected by the wifi
+        // restriction policy
+        final VcnUnderlyingNetworkPolicy policy =
+                mVcnMgmtSvc.getUnderlyingNetworkPolicy(
+                        getNetworkCapabilitiesBuilderForTransport(
+                                        TEST_SUBSCRIPTION_ID, TRANSPORT_CELLULAR)
+                                .build(),
+                        new LinkProperties());
 
         assertFalse(policy.isTeardownRequested());
         verifyMergedNetworkCapabilities(
                 policy.getMergedNetworkCapabilities(),
                 TRANSPORT_CELLULAR,
-                true /* isVcnManaged */,
-                false /* isRestricted */);
-    }
-
-    @Test
-    public void testGetUnderlyingNetworkPolicyCellular_safeMode() throws Exception {
-        final VcnUnderlyingNetworkPolicy policy =
-                startVcnAndGetPolicyForTransport(
-                        TEST_SUBSCRIPTION_ID,
-                        TEST_UUID_2,
-                        false /* isActive */,
-                        TRANSPORT_CELLULAR);
-
-        assertFalse(policy.isTeardownRequested());
-        verifyMergedNetworkCapabilities(
-                policy.getMergedNetworkCapabilities(),
-                NetworkCapabilities.TRANSPORT_CELLULAR,
-                false /* isVcnManaged */,
-                false /* isRestricted */);
-    }
-
-    @Test
-    public void testGetUnderlyingNetworkPolicyWifi() throws Exception {
-        final VcnUnderlyingNetworkPolicy policy =
-                startVcnAndGetPolicyForTransport(
-                        TEST_SUBSCRIPTION_ID, TEST_UUID_2, true /* isActive */, TRANSPORT_WIFI);
-
-        assertFalse(policy.isTeardownRequested());
-        verifyMergedNetworkCapabilities(
-                policy.getMergedNetworkCapabilities(),
-                NetworkCapabilities.TRANSPORT_WIFI,
-                true /* isVcnManaged */,
-                true /* isRestricted */);
-    }
-
-    @Test
-    public void testGetUnderlyingNetworkPolicyVcnWifi_safeMode() throws Exception {
-        final VcnUnderlyingNetworkPolicy policy =
-                startVcnAndGetPolicyForTransport(
-                        TEST_SUBSCRIPTION_ID, TEST_UUID_2, false /* isActive */, TRANSPORT_WIFI);
-
-        assertFalse(policy.isTeardownRequested());
-        verifyMergedNetworkCapabilities(
-                policy.getMergedNetworkCapabilities(),
-                NetworkCapabilities.TRANSPORT_WIFI,
-                false /* isVcnManaged */,
-                true /* isRestricted */);
+                true /* expectVcnManaged */,
+                false /* expectRestricted */);
     }
 
     private void setupTrackedCarrierWifiNetwork(NetworkCapabilities caps) {
@@ -1139,6 +1216,27 @@
     }
 
     @Test
+    public void testGetUnderlyingNetworkPolicyForRestrictedImsWhenUnrestrictingCell()
+            throws Exception {
+        final NetworkCapabilities existingNetworkCaps =
+                getNetworkCapabilitiesBuilderForTransport(TEST_SUBSCRIPTION_ID, TRANSPORT_CELLULAR)
+                        .addCapability(NET_CAPABILITY_NOT_RESTRICTED)
+                        .removeCapability(NET_CAPABILITY_IMS)
+                        .build();
+        setupTrackedCarrierWifiNetwork(existingNetworkCaps);
+
+        final VcnUnderlyingNetworkPolicy policy =
+                mVcnMgmtSvc.getUnderlyingNetworkPolicy(
+                        getNetworkCapabilitiesBuilderForTransport(
+                                        TEST_SUBSCRIPTION_ID, TRANSPORT_CELLULAR)
+                                .addCapability(NET_CAPABILITY_IMS)
+                                .removeCapability(NET_CAPABILITY_NOT_RESTRICTED)
+                                .build(),
+                        new LinkProperties());
+        assertFalse(policy.isTeardownRequested());
+    }
+
+    @Test
     public void testGetUnderlyingNetworkPolicyNonVcnNetwork() throws Exception {
         setupSubscriptionAndStartVcn(TEST_SUBSCRIPTION_ID, TEST_UUID_1, true /* isActive */);