Add restrictions on underlying networks with CarrierConfig

This commit updates the VCN code to use CarrierConfig to
store the restriction policy configuration, and notify
policy listners when the CarrierConfig gets updated.

Bug: 239104955
Test: atest FrameworksVcnTests(new tests)
Test: atest CtsVcnTestCases
Test: manually verified by overriding carrier configs
Change-Id: Iaeb1871b7d9f5c3d1b4e5ec24cfbf79eea58256f
diff --git a/core/java/android/net/vcn/VcnManager.java b/core/java/android/net/vcn/VcnManager.java
index 40e4083..3a7aea5 100644
--- a/core/java/android/net/vcn/VcnManager.java
+++ b/core/java/android/net/vcn/VcnManager.java
@@ -104,12 +104,23 @@
 
     // TODO: Add separate signal strength thresholds for 2.4 GHz and 5GHz
 
+    /**
+     * Key for transports that need to be marked as restricted by the VCN
+     *
+     * <p>Defaults to TRANSPORT_WIFI if the config does not exist
+     *
+     * @hide
+     */
+    public static final String VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY =
+            "vcn_restricted_transports";
+
     /** List of Carrier Config options to extract from Carrier Config bundles. @hide */
     @NonNull
     public static final String[] VCN_RELATED_CARRIER_CONFIG_KEYS =
             new String[] {
                 VCN_NETWORK_SELECTION_WIFI_ENTRY_RSSI_THRESHOLD_KEY,
-                VCN_NETWORK_SELECTION_WIFI_EXIT_RSSI_THRESHOLD_KEY
+                VCN_NETWORK_SELECTION_WIFI_EXIT_RSSI_THRESHOLD_KEY,
+                VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY
             };
 
     private static final Map<
diff --git a/services/core/java/com/android/server/VcnManagementService.java b/services/core/java/com/android/server/VcnManagementService.java
index 6f49db1..61f7f30 100644
--- a/services/core/java/com/android/server/VcnManagementService.java
+++ b/services/core/java/com/android/server/VcnManagementService.java
@@ -22,6 +22,7 @@
 import static android.net.NetworkCapabilities.TRANSPORT_TEST;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.net.vcn.VcnGatewayConnectionConfig.ALLOWED_CAPABILITIES;
+import static android.net.vcn.VcnManager.VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY;
 import static android.net.vcn.VcnManager.VCN_STATUS_CODE_ACTIVE;
 import static android.net.vcn.VcnManager.VCN_STATUS_CODE_INACTIVE;
 import static android.net.vcn.VcnManager.VCN_STATUS_CODE_NOT_CONFIGURED;
@@ -55,6 +56,7 @@
 import android.net.vcn.VcnUnderlyingNetworkPolicy;
 import android.net.wifi.WifiInfo;
 import android.os.Binder;
+import android.os.Build;
 import android.os.Environment;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -86,6 +88,7 @@
 import com.android.server.vcn.VcnContext;
 import com.android.server.vcn.VcnNetworkProvider;
 import com.android.server.vcn.util.PersistableBundleUtils;
+import com.android.server.vcn.util.PersistableBundleUtils.PersistableBundleWrapper;
 
 import java.io.File;
 import java.io.FileDescriptor;
@@ -162,6 +165,9 @@
     private static final long DUMP_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(5);
     private static final int LOCAL_LOG_LINE_COUNT = 512;
 
+    private static final Set<Integer> RESTRICTED_TRANSPORTS_DEFAULT =
+            Collections.singleton(TRANSPORT_WIFI);
+
     // Public for use in all other VCN classes
     @NonNull public static final LocalLog LOCAL_LOG = new LocalLog(LOCAL_LOG_LINE_COUNT);
 
@@ -367,12 +373,30 @@
 
         /** Gets the transports that need to be marked as restricted by the VCN */
         public Set<Integer> getRestrictedTransports(
-                ParcelUuid subGrp,
-                Map<ParcelUuid, VcnConfig> vcnConfigs,
-                TelephonySubscriptionSnapshot lastSnapshot) {
-            // TODO: b/239104955 Read restriction policy configurations
+                ParcelUuid subGrp, TelephonySubscriptionSnapshot lastSnapshot) {
+            if (!Build.IS_ENG && !Build.IS_USERDEBUG) {
+                return RESTRICTED_TRANSPORTS_DEFAULT;
+            }
 
-            return Collections.singleton(TRANSPORT_WIFI);
+            final PersistableBundleWrapper carrierConfig =
+                    lastSnapshot.getCarrierConfigForSubGrp(subGrp);
+            if (carrierConfig == null) {
+                return RESTRICTED_TRANSPORTS_DEFAULT;
+            }
+
+            final int[] defaultValue =
+                    RESTRICTED_TRANSPORTS_DEFAULT.stream().mapToInt(i -> i).toArray();
+            final int[] restrictedTransportsArray =
+                    carrierConfig.getIntArray(
+                            VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY,
+                            defaultValue);
+
+            // Convert to a boxed set
+            final Set<Integer> restrictedTransports = new ArraySet<>();
+            for (int transport : restrictedTransportsArray) {
+                restrictedTransports.add(transport);
+            }
+            return restrictedTransports;
         }
     }
 
@@ -530,6 +554,7 @@
                     }
                 }
 
+                boolean needNotifyAllPolicyListeners = false;
                 // Schedule teardown of any VCN instances that have lost carrier privileges (after a
                 // delay)
                 for (Entry<ParcelUuid, Vcn> entry : mVcns.entrySet()) {
@@ -577,6 +602,10 @@
                     } else {
                         // If this VCN's status has not changed, update it with the new snapshot
                         entry.getValue().updateSubscriptionSnapshot(mLastSnapshot);
+                        needNotifyAllPolicyListeners |=
+                                !Objects.equals(
+                                        oldSnapshot.getCarrierConfigForSubGrp(subGrp),
+                                        mLastSnapshot.getCarrierConfigForSubGrp(subGrp));
                     }
                 }
 
@@ -586,6 +615,10 @@
                         getSubGroupToSubIdMappings(mLastSnapshot);
                 if (!currSubGrpMappings.equals(oldSubGrpMappings)) {
                     garbageCollectAndWriteVcnConfigsLocked();
+                    needNotifyAllPolicyListeners = true;
+                }
+
+                if (needNotifyAllPolicyListeners) {
                     notifyAllPolicyListenersLocked();
                 }
             }
@@ -930,6 +963,14 @@
         });
     }
 
+    @VisibleForTesting(visibility = Visibility.PRIVATE)
+    void addVcnUnderlyingNetworkPolicyListenerForTest(
+            @NonNull IVcnUnderlyingNetworkPolicyListener listener) {
+        synchronized (mLock) {
+            addVcnUnderlyingNetworkPolicyListener(listener);
+        }
+    }
+
     /** Removes the provided listener from receiving VcnUnderlyingNetworkPolicy updates. */
     @GuardedBy("mLock")
     @Override
@@ -1022,7 +1063,7 @@
                     }
 
                     final Set<Integer> restrictedTransports =
-                            mDeps.getRestrictedTransports(subGrp, mConfigs, mLastSnapshot);
+                            mDeps.getRestrictedTransports(subGrp, mLastSnapshot);
                     for (int restrictedTransport : restrictedTransports) {
                         if (ncCopy.hasTransport(restrictedTransport)) {
                             if (restrictedTransport == TRANSPORT_CELLULAR) {
diff --git a/services/core/java/com/android/server/vcn/TelephonySubscriptionTracker.java b/services/core/java/com/android/server/vcn/TelephonySubscriptionTracker.java
index 5c305c6..ca4a32f 100644
--- a/services/core/java/com/android/server/vcn/TelephonySubscriptionTracker.java
+++ b/services/core/java/com/android/server/vcn/TelephonySubscriptionTracker.java
@@ -390,8 +390,13 @@
             Objects.requireNonNull(privilegedPackages, "privilegedPackages was null");
             Objects.requireNonNull(subIdToCarrierConfigMap, "subIdToCarrierConfigMap was null");
 
-            mSubIdToInfoMap = Collections.unmodifiableMap(subIdToInfoMap);
-            mSubIdToCarrierConfigMap = Collections.unmodifiableMap(subIdToCarrierConfigMap);
+            mSubIdToInfoMap =
+                    Collections.unmodifiableMap(
+                            new HashMap<Integer, SubscriptionInfo>(subIdToInfoMap));
+            mSubIdToCarrierConfigMap =
+                    Collections.unmodifiableMap(
+                            new HashMap<Integer, PersistableBundleWrapper>(
+                                    subIdToCarrierConfigMap));
 
             final Map<ParcelUuid, Set<String>> unmodifiableInnerSets = new ArrayMap<>();
             for (Entry<ParcelUuid, Set<String>> entry : privilegedPackages.entrySet()) {
diff --git a/services/core/java/com/android/server/vcn/util/PersistableBundleUtils.java b/services/core/java/com/android/server/vcn/util/PersistableBundleUtils.java
index 999d406..d22ec0a 100644
--- a/services/core/java/com/android/server/vcn/util/PersistableBundleUtils.java
+++ b/services/core/java/com/android/server/vcn/util/PersistableBundleUtils.java
@@ -544,6 +544,20 @@
             return mBundle.getInt(key, defaultValue);
         }
 
+        /**
+         * Returns the value associated with the given key, or null if no mapping of the desired
+         * type exists for the given key or a null value is explicitly associated with the key.
+         *
+         * @param key a String, or null
+         * @param defaultValue the value to return if key does not exist
+         * @return an int[] value, or null
+         */
+        @Nullable
+        public int[] getIntArray(@Nullable String key, @Nullable int[] defaultValue) {
+            final int[] value = mBundle.getIntArray(key);
+            return value == null ? defaultValue : value;
+        }
+
         @Override
         public int hashCode() {
             return getHashCode(mBundle);
diff --git a/tests/vcn/java/com/android/server/VcnManagementServiceTest.java b/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
index 478afe8..eae8873 100644
--- a/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
+++ b/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
@@ -22,6 +22,7 @@
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+import static android.net.vcn.VcnManager.VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY;
 import static android.net.vcn.VcnManager.VCN_STATUS_CODE_ACTIVE;
 import static android.net.vcn.VcnManager.VCN_STATUS_CODE_SAFE_MODE;
 import static android.telephony.SubscriptionManager.INVALID_SUBSCRIPTION_ID;
@@ -96,6 +97,7 @@
 import com.android.server.vcn.VcnContext;
 import com.android.server.vcn.VcnNetworkProvider;
 import com.android.server.vcn.util.PersistableBundleUtils;
+import com.android.server.vcn.util.PersistableBundleUtils.PersistableBundleWrapper;
 
 import org.junit.Before;
 import org.junit.Test;
@@ -256,7 +258,7 @@
 
         doReturn(Collections.singleton(TRANSPORT_WIFI))
                 .when(mMockDeps)
-                .getRestrictedTransports(any(), any(), any());
+                .getRestrictedTransports(any(), any());
     }
 
 
@@ -1037,6 +1039,59 @@
                 new LinkProperties());
     }
 
+    private void checkGetRestrictedTransports(
+            ParcelUuid subGrp,
+            TelephonySubscriptionSnapshot lastSnapshot,
+            Set<Integer> expectedTransports) {
+        Set<Integer> result =
+                new VcnManagementService.Dependencies()
+                        .getRestrictedTransports(subGrp, lastSnapshot);
+        assertEquals(expectedTransports, result);
+    }
+
+    @Test
+    public void testGetRestrictedTransports() {
+        final Set<Integer> restrictedTransports = new ArraySet<>();
+        restrictedTransports.add(TRANSPORT_CELLULAR);
+        restrictedTransports.add(TRANSPORT_WIFI);
+
+        PersistableBundle carrierConfigBundle = new PersistableBundle();
+        carrierConfigBundle.putIntArray(
+                VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY,
+                restrictedTransports.stream().mapToInt(i -> i).toArray());
+        final PersistableBundleWrapper carrierConfig =
+                new PersistableBundleWrapper(carrierConfigBundle);
+
+        final TelephonySubscriptionSnapshot lastSnapshot =
+                mock(TelephonySubscriptionSnapshot.class);
+        doReturn(carrierConfig).when(lastSnapshot).getCarrierConfigForSubGrp(eq(TEST_UUID_2));
+
+        checkGetRestrictedTransports(TEST_UUID_2, lastSnapshot, restrictedTransports);
+    }
+
+    @Test
+    public void testGetRestrictedTransports_noRestrictPolicyConfigured() {
+        final Set<Integer> restrictedTransports = Collections.singleton(TRANSPORT_WIFI);
+
+        final PersistableBundleWrapper carrierConfig =
+                new PersistableBundleWrapper(new PersistableBundle());
+        final TelephonySubscriptionSnapshot lastSnapshot =
+                mock(TelephonySubscriptionSnapshot.class);
+        doReturn(carrierConfig).when(lastSnapshot).getCarrierConfigForSubGrp(eq(TEST_UUID_2));
+
+        checkGetRestrictedTransports(TEST_UUID_2, lastSnapshot, restrictedTransports);
+    }
+
+    @Test
+    public void testGetRestrictedTransports_noCarrierConfig() {
+        final Set<Integer> restrictedTransports = Collections.singleton(TRANSPORT_WIFI);
+
+        final TelephonySubscriptionSnapshot lastSnapshot =
+                mock(TelephonySubscriptionSnapshot.class);
+
+        checkGetRestrictedTransports(TEST_UUID_2, lastSnapshot, restrictedTransports);
+    }
+
     private void checkGetUnderlyingNetworkPolicy(
             int transportType,
             boolean isTransportRestricted,
@@ -1049,7 +1104,7 @@
         if (isTransportRestricted) {
             restrictedTransports.add(transportType);
         }
-        doReturn(restrictedTransports).when(mMockDeps).getRestrictedTransports(any(), any(), any());
+        doReturn(restrictedTransports).when(mMockDeps).getRestrictedTransports(any(), any());
 
         final VcnUnderlyingNetworkPolicy policy =
                 startVcnAndGetPolicyForTransport(
@@ -1147,7 +1202,7 @@
     public void testGetUnderlyingNetworkPolicyCell_restrictWifi() throws Exception {
         doReturn(Collections.singleton(TRANSPORT_WIFI))
                 .when(mMockDeps)
-                .getRestrictedTransports(any(), any(), any());
+                .getRestrictedTransports(any(), any());
 
         setupSubscriptionAndStartVcn(TEST_SUBSCRIPTION_ID, TEST_UUID_2, true /* isVcnActive */);
 
@@ -1316,6 +1371,30 @@
         verify(mMockPolicyListener).onPolicyChanged();
     }
 
+    @Test
+    public void testVcnCarrierConfigChangeUpdatesPolicyListener() throws Exception {
+        setupActiveSubscription(TEST_UUID_2);
+
+        mVcnMgmtSvc.setVcnConfig(TEST_UUID_2, TEST_VCN_CONFIG, TEST_PACKAGE_NAME);
+        mVcnMgmtSvc.addVcnUnderlyingNetworkPolicyListenerForTest(mMockPolicyListener);
+
+        final TelephonySubscriptionSnapshot snapshot =
+                buildSubscriptionSnapshot(
+                        TEST_SUBSCRIPTION_ID,
+                        TEST_UUID_2,
+                        Collections.singleton(TEST_UUID_2),
+                        Collections.emptyMap(),
+                        true /* hasCarrierPrivileges */);
+
+        final PersistableBundleWrapper mockCarrierConfig = mock(PersistableBundleWrapper.class);
+        doReturn(mockCarrierConfig).when(snapshot).getCarrierConfigForSubGrp(eq(TEST_UUID_2));
+
+        final TelephonySubscriptionTrackerCallback cb = getTelephonySubscriptionTrackerCallback();
+        cb.onNewSnapshot(snapshot);
+
+        verify(mMockPolicyListener).onPolicyChanged();
+    }
+
     private void triggerVcnSafeMode(
             @NonNull ParcelUuid subGroup,
             @NonNull TelephonySubscriptionSnapshot snapshot,
diff --git a/tests/vcn/java/com/android/server/vcn/TelephonySubscriptionTrackerTest.java b/tests/vcn/java/com/android/server/vcn/TelephonySubscriptionTrackerTest.java
index 09080be..965b073 100644
--- a/tests/vcn/java/com/android/server/vcn/TelephonySubscriptionTrackerTest.java
+++ b/tests/vcn/java/com/android/server/vcn/TelephonySubscriptionTrackerTest.java
@@ -16,6 +16,9 @@
 
 package com.android.server.vcn;
 
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+import static android.net.vcn.VcnManager.VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY;
 import static android.telephony.CarrierConfigManager.ACTION_CARRIER_CONFIG_CHANGED;
 import static android.telephony.CarrierConfigManager.EXTRA_SLOT_INDEX;
 import static android.telephony.CarrierConfigManager.EXTRA_SUBSCRIPTION_INDEX;
@@ -39,6 +42,7 @@
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -494,6 +498,37 @@
     }
 
     @Test
+    public void testCarrierConfigUpdatedAfterValidTriggersCallbacks() throws Exception {
+        mTelephonySubscriptionTracker.onReceive(mContext, buildTestBroadcastIntent(true));
+        mTestLooper.dispatchAll();
+        verify(mCallback).onNewSnapshot(eq(buildExpectedSnapshot(TEST_PRIVILEGED_PACKAGES)));
+        reset(mCallback);
+
+        final PersistableBundle updatedConfig = new PersistableBundle();
+        updatedConfig.putIntArray(
+                VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY,
+                new int[] {TRANSPORT_WIFI, TRANSPORT_CELLULAR});
+        doReturn(updatedConfig)
+                .when(mCarrierConfigManager)
+                .getConfigForSubId(eq(TEST_SUBSCRIPTION_ID_1));
+
+        Map<Integer, PersistableBundleWrapper> subIdToCarrierConfigMap = new HashMap<>();
+        subIdToCarrierConfigMap.put(
+                TEST_SUBSCRIPTION_ID_1, new PersistableBundleWrapper(updatedConfig));
+        mTelephonySubscriptionTracker.onReceive(mContext, buildTestBroadcastIntent(true));
+        mTestLooper.dispatchAll();
+
+        verify(mCallback)
+                .onNewSnapshot(
+                        eq(
+                                buildExpectedSnapshot(
+                                        0,
+                                        TEST_SUBID_TO_INFO_MAP,
+                                        subIdToCarrierConfigMap,
+                                        TEST_PRIVILEGED_PACKAGES)));
+    }
+
+    @Test
     public void testSlotClearedAfterValidTriggersCallbacks() throws Exception {
         mTelephonySubscriptionTracker.onReceive(mContext, buildTestBroadcastIntent(true));
         mTestLooper.dispatchAll();