Add restrictions on underlying networks with VcnConfig

This commit allows VCN callers to configure underlying networks
that need to be restricted via VcnConfig.

This commit also makes sure that VCN will notify policy listners
when the VcnConfig gets updated.

Bug: 239104955
Test: atest FrameworksVcnTests(new tests)
Test: atest CtsVcnTestCases
Change-Id: Ie174f7ec27ba115939c4f5d88d9bb00c6d348ea9
diff --git a/core/java/android/net/vcn/VcnConfig.java b/core/java/android/net/vcn/VcnConfig.java
index fd3fe37..8627f5c 100644
--- a/core/java/android/net/vcn/VcnConfig.java
+++ b/core/java/android/net/vcn/VcnConfig.java
@@ -15,7 +15,12 @@
  */
 package android.net.vcn;
 
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+
 import static com.android.internal.annotations.VisibleForTesting.Visibility;
+import static com.android.server.vcn.util.PersistableBundleUtils.INTEGER_DESERIALIZER;
+import static com.android.server.vcn.util.PersistableBundleUtils.INTEGER_SERIALIZER;
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
@@ -25,6 +30,7 @@
 import android.os.Parcelable;
 import android.os.PersistableBundle;
 import android.util.ArraySet;
+import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.Preconditions;
@@ -32,6 +38,7 @@
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Iterator;
 import java.util.Objects;
 import java.util.Set;
 
@@ -46,22 +53,36 @@
 public final class VcnConfig implements Parcelable {
     @NonNull private static final String TAG = VcnConfig.class.getSimpleName();
 
+    private static final Set<Integer> ALLOWED_TRANSPORTS = new ArraySet<>();
+
+    static {
+        ALLOWED_TRANSPORTS.add(TRANSPORT_WIFI);
+        ALLOWED_TRANSPORTS.add(TRANSPORT_CELLULAR);
+    }
+
     private static final String PACKAGE_NAME_KEY = "mPackageName";
     @NonNull private final String mPackageName;
 
     private static final String GATEWAY_CONNECTION_CONFIGS_KEY = "mGatewayConnectionConfigs";
     @NonNull private final Set<VcnGatewayConnectionConfig> mGatewayConnectionConfigs;
 
+    private static final Set<Integer> RESTRICTED_TRANSPORTS_DEFAULT =
+            Collections.singleton(TRANSPORT_WIFI);
+    private static final String RESTRICTED_TRANSPORTS_KEY = "mRestrictedTransports";
+    @NonNull private final Set<Integer> mRestrictedTransports;
+
     private static final String IS_TEST_MODE_PROFILE_KEY = "mIsTestModeProfile";
     private final boolean mIsTestModeProfile;
 
     private VcnConfig(
             @NonNull String packageName,
             @NonNull Set<VcnGatewayConnectionConfig> gatewayConnectionConfigs,
+            @NonNull Set<Integer> restrictedTransports,
             boolean isTestModeProfile) {
         mPackageName = packageName;
         mGatewayConnectionConfigs =
                 Collections.unmodifiableSet(new ArraySet<>(gatewayConnectionConfigs));
+        mRestrictedTransports = Collections.unmodifiableSet(new ArraySet<>(restrictedTransports));
         mIsTestModeProfile = isTestModeProfile;
 
         validate();
@@ -82,6 +103,20 @@
                 new ArraySet<>(
                         PersistableBundleUtils.toList(
                                 gatewayConnectionConfigsBundle, VcnGatewayConnectionConfig::new));
+
+        final PersistableBundle restrictedTransportsBundle =
+                in.getPersistableBundle(RESTRICTED_TRANSPORTS_KEY);
+        if (restrictedTransportsBundle == null) {
+            // RESTRICTED_TRANSPORTS_KEY was added in U and does not exist in VcnConfigs created in
+            // older platforms
+            mRestrictedTransports = RESTRICTED_TRANSPORTS_DEFAULT;
+        } else {
+            mRestrictedTransports =
+                    new ArraySet<Integer>(
+                            PersistableBundleUtils.toList(
+                                    restrictedTransportsBundle, INTEGER_DESERIALIZER));
+        }
+
         mIsTestModeProfile = in.getBoolean(IS_TEST_MODE_PROFILE_KEY);
 
         validate();
@@ -91,6 +126,19 @@
         Objects.requireNonNull(mPackageName, "packageName was null");
         Preconditions.checkCollectionNotEmpty(
                 mGatewayConnectionConfigs, "gatewayConnectionConfigs was empty");
+
+        final Iterator<Integer> iterator = mRestrictedTransports.iterator();
+        while (iterator.hasNext()) {
+            final int transport = iterator.next();
+            if (!ALLOWED_TRANSPORTS.contains(transport)) {
+                iterator.remove();
+                Log.w(
+                        TAG,
+                        "Found invalid transport "
+                                + transport
+                                + " which might be from a new version of VcnConfig");
+            }
+        }
     }
 
     /**
@@ -110,6 +158,17 @@
     }
 
     /**
+     * Retrieve the transports that need to be restricted by VCN.
+     *
+     * @see Builder#setRestrictedUnderlyingNetworkTransports(Set)
+     * @hide
+     */
+    @NonNull
+    public Set<Integer> getRestrictedUnderlyingNetworkTransports() {
+        return Collections.unmodifiableSet(mRestrictedTransports);
+    }
+
+    /**
      * Returns whether or not this VcnConfig is restricted to test networks.
      *
      * @hide
@@ -134,6 +193,12 @@
                         new ArrayList<>(mGatewayConnectionConfigs),
                         VcnGatewayConnectionConfig::toPersistableBundle);
         result.putPersistableBundle(GATEWAY_CONNECTION_CONFIGS_KEY, gatewayConnectionConfigsBundle);
+
+        final PersistableBundle restrictedTransportsBundle =
+                PersistableBundleUtils.fromList(
+                        new ArrayList<>(mRestrictedTransports), INTEGER_SERIALIZER);
+        result.putPersistableBundle(RESTRICTED_TRANSPORTS_KEY, restrictedTransportsBundle);
+
         result.putBoolean(IS_TEST_MODE_PROFILE_KEY, mIsTestModeProfile);
 
         return result;
@@ -141,7 +206,8 @@
 
     @Override
     public int hashCode() {
-        return Objects.hash(mPackageName, mGatewayConnectionConfigs, mIsTestModeProfile);
+        return Objects.hash(
+                mPackageName, mGatewayConnectionConfigs, mRestrictedTransports, mIsTestModeProfile);
     }
 
     @Override
@@ -153,6 +219,7 @@
         final VcnConfig rhs = (VcnConfig) other;
         return mPackageName.equals(rhs.mPackageName)
                 && mGatewayConnectionConfigs.equals(rhs.mGatewayConnectionConfigs)
+                && mRestrictedTransports.equals(rhs.mRestrictedTransports)
                 && mIsTestModeProfile == rhs.mIsTestModeProfile;
     }
 
@@ -189,12 +256,15 @@
         @NonNull
         private final Set<VcnGatewayConnectionConfig> mGatewayConnectionConfigs = new ArraySet<>();
 
+        @NonNull private final Set<Integer> mRestrictedTransports = new ArraySet<>();
+
         private boolean mIsTestModeProfile = false;
 
         public Builder(@NonNull Context context) {
             Objects.requireNonNull(context, "context was null");
 
             mPackageName = context.getOpPackageName();
+            mRestrictedTransports.addAll(RESTRICTED_TRANSPORTS_DEFAULT);
         }
 
         /**
@@ -225,6 +295,37 @@
             return this;
         }
 
+        private void validateRestrictedTransportsOrThrow(Set<Integer> restrictedTransports) {
+            Objects.requireNonNull(restrictedTransports, "transports was null");
+
+            for (int transport : restrictedTransports) {
+                if (!ALLOWED_TRANSPORTS.contains(transport)) {
+                    throw new IllegalArgumentException("Invalid transport " + transport);
+                }
+            }
+        }
+
+        /**
+         * Sets transports that need to be restricted by VCN.
+         *
+         * @param transports transports that need to be restricted by VCN. Networks that include any
+         *     of the transports will be marked as restricted. Only {@link
+         *     NetworkCapabilities#TRANSPORT_WIFI} and {@link
+         *     NetworkCapabilities#TRANSPORT_CELLULAR} are allowed. {@link
+         *     NetworkCapabilities#TRANSPORT_WIFI} is marked restricted by default.
+         * @return this {@link Builder} instance, for chaining
+         * @throws IllegalArgumentException if the input contains unsupported transport types.
+         * @hide
+         */
+        @NonNull
+        public Builder setRestrictedUnderlyingNetworkTransports(@NonNull Set<Integer> transports) {
+            validateRestrictedTransportsOrThrow(transports);
+
+            mRestrictedTransports.clear();
+            mRestrictedTransports.addAll(transports);
+            return this;
+        }
+
         /**
          * Restricts this VcnConfig to matching with test networks (only).
          *
@@ -248,7 +349,11 @@
          */
         @NonNull
         public VcnConfig build() {
-            return new VcnConfig(mPackageName, mGatewayConnectionConfigs, mIsTestModeProfile);
+            return new VcnConfig(
+                    mPackageName,
+                    mGatewayConnectionConfigs,
+                    mRestrictedTransports,
+                    mIsTestModeProfile);
         }
     }
 }
diff --git a/services/core/java/com/android/server/VcnManagementService.java b/services/core/java/com/android/server/VcnManagementService.java
index 61f7f30..f652cb0 100644
--- a/services/core/java/com/android/server/VcnManagementService.java
+++ b/services/core/java/com/android/server/VcnManagementService.java
@@ -371,8 +371,9 @@
             return new LocationPermissionChecker(context);
         }
 
-        /** Gets the transports that need to be marked as restricted by the VCN */
-        public Set<Integer> getRestrictedTransports(
+        /** Gets transports that need to be marked as restricted by the VCN from CarrierConfig */
+        @VisibleForTesting(visibility = Visibility.PRIVATE)
+        public Set<Integer> getRestrictedTransportsFromCarrierConfig(
                 ParcelUuid subGrp, TelephonySubscriptionSnapshot lastSnapshot) {
             if (!Build.IS_ENG && !Build.IS_USERDEBUG) {
                 return RESTRICTED_TRANSPORTS_DEFAULT;
@@ -398,6 +399,22 @@
             }
             return restrictedTransports;
         }
+
+        /** Gets the transports that need to be marked as restricted by the VCN */
+        public Set<Integer> getRestrictedTransports(
+                ParcelUuid subGrp,
+                TelephonySubscriptionSnapshot lastSnapshot,
+                VcnConfig vcnConfig) {
+            final Set<Integer> restrictedTransports = new ArraySet<>();
+            restrictedTransports.addAll(vcnConfig.getRestrictedUnderlyingNetworkTransports());
+
+            // TODO: b/262269892 Remove the ability to configure restricted transports
+            // via CarrierConfig
+            restrictedTransports.addAll(
+                    getRestrictedTransportsFromCarrierConfig(subGrp, lastSnapshot));
+
+            return restrictedTransports;
+        }
     }
 
     /** Notifies the VcnManagementService that external dependencies can be set up. */
@@ -719,6 +736,7 @@
         if (mVcns.containsKey(subscriptionGroup)) {
             final Vcn vcn = mVcns.get(subscriptionGroup);
             vcn.updateConfig(config);
+            notifyAllPolicyListenersLocked();
         } else {
             // TODO(b/193687515): Support multiple VCNs active at the same time
             if (isActiveSubGroup(subscriptionGroup, mLastSnapshot)) {
@@ -936,7 +954,6 @@
     }
 
     /** Adds the provided listener for receiving VcnUnderlyingNetworkPolicy updates. */
-    @GuardedBy("mLock")
     @Override
     public void addVcnUnderlyingNetworkPolicyListener(
             @NonNull IVcnUnderlyingNetworkPolicyListener listener) {
@@ -963,16 +980,7 @@
         });
     }
 
-    @VisibleForTesting(visibility = Visibility.PRIVATE)
-    void addVcnUnderlyingNetworkPolicyListenerForTest(
-            @NonNull IVcnUnderlyingNetworkPolicyListener listener) {
-        synchronized (mLock) {
-            addVcnUnderlyingNetworkPolicyListener(listener);
-        }
-    }
-
     /** Removes the provided listener from receiving VcnUnderlyingNetworkPolicy updates. */
-    @GuardedBy("mLock")
     @Override
     public void removeVcnUnderlyingNetworkPolicyListener(
             @NonNull IVcnUnderlyingNetworkPolicyListener listener) {
@@ -1062,8 +1070,8 @@
                         isVcnManagedNetwork = true;
                     }
 
-                    final Set<Integer> restrictedTransports =
-                            mDeps.getRestrictedTransports(subGrp, mLastSnapshot);
+                    final Set<Integer> restrictedTransports = mDeps.getRestrictedTransports(
+                            subGrp, mLastSnapshot, mConfigs.get(subGrp));
                     for (int restrictedTransport : restrictedTransports) {
                         if (ncCopy.hasTransport(restrictedTransport)) {
                             if (restrictedTransport == TRANSPORT_CELLULAR) {
diff --git a/tests/vcn/java/android/net/vcn/VcnConfigTest.java b/tests/vcn/java/android/net/vcn/VcnConfigTest.java
index 7ac51b7..b313c9f 100644
--- a/tests/vcn/java/android/net/vcn/VcnConfigTest.java
+++ b/tests/vcn/java/android/net/vcn/VcnConfigTest.java
@@ -16,7 +16,12 @@
 
 package android.net.vcn;
 
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
@@ -24,6 +29,7 @@
 import android.annotation.NonNull;
 import android.content.Context;
 import android.os.Parcel;
+import android.util.ArraySet;
 
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
@@ -42,19 +48,36 @@
     private static final Set<VcnGatewayConnectionConfig> GATEWAY_CONNECTION_CONFIGS =
             Collections.singleton(VcnGatewayConnectionConfigTest.buildTestConfig());
 
+    private static final Set<Integer> RESTRICTED_TRANSPORTS = new ArraySet<>();
+
+    static {
+        RESTRICTED_TRANSPORTS.add(TRANSPORT_WIFI);
+        RESTRICTED_TRANSPORTS.add(TRANSPORT_CELLULAR);
+    }
+
     private final Context mContext = mock(Context.class);
 
     // Public visibility for VcnManagementServiceTest
-    public static VcnConfig buildTestConfig(@NonNull Context context) {
+    public static VcnConfig buildTestConfig(
+            @NonNull Context context, Set<Integer> restrictedTransports) {
         VcnConfig.Builder builder = new VcnConfig.Builder(context);
 
         for (VcnGatewayConnectionConfig gatewayConnectionConfig : GATEWAY_CONNECTION_CONFIGS) {
             builder.addGatewayConnectionConfig(gatewayConnectionConfig);
         }
 
+        if (restrictedTransports != null) {
+            builder.setRestrictedUnderlyingNetworkTransports(restrictedTransports);
+        }
+
         return builder.build();
     }
 
+    // Public visibility for VcnManagementServiceTest
+    public static VcnConfig buildTestConfig(@NonNull Context context) {
+        return buildTestConfig(context, null);
+    }
+
     @Before
     public void setUp() throws Exception {
         doReturn(TEST_PACKAGE_NAME).when(mContext).getOpPackageName();
@@ -91,11 +114,25 @@
     }
 
     @Test
-    public void testBuilderAndGetters() {
+    public void testBuilderAndGettersDefaultValues() {
         final VcnConfig config = buildTestConfig(mContext);
 
         assertEquals(TEST_PACKAGE_NAME, config.getProvisioningPackageName());
         assertEquals(GATEWAY_CONNECTION_CONFIGS, config.getGatewayConnectionConfigs());
+        assertFalse(config.isTestModeProfile());
+        assertEquals(
+                Collections.singleton(TRANSPORT_WIFI),
+                config.getRestrictedUnderlyingNetworkTransports());
+    }
+
+    @Test
+    public void testBuilderAndGettersConfigRestrictedTransports() {
+        final VcnConfig config = buildTestConfig(mContext, RESTRICTED_TRANSPORTS);
+
+        assertEquals(TEST_PACKAGE_NAME, config.getProvisioningPackageName());
+        assertEquals(GATEWAY_CONNECTION_CONFIGS, config.getGatewayConnectionConfigs());
+        assertFalse(config.isTestModeProfile());
+        assertEquals(RESTRICTED_TRANSPORTS, config.getRestrictedUnderlyingNetworkTransports());
     }
 
     @Test
@@ -106,6 +143,24 @@
     }
 
     @Test
+    public void testPersistableBundleWithRestrictedTransports() {
+        final VcnConfig config = buildTestConfig(mContext, RESTRICTED_TRANSPORTS);
+
+        assertEquals(config, new VcnConfig(config.toPersistableBundle()));
+    }
+
+    @Test
+    public void testEqualityWithRestrictedTransports() {
+        final VcnConfig config = buildTestConfig(mContext, RESTRICTED_TRANSPORTS);
+        final VcnConfig configEqual = buildTestConfig(mContext, RESTRICTED_TRANSPORTS);
+        final VcnConfig configNotEqual =
+                buildTestConfig(mContext, Collections.singleton(TRANSPORT_WIFI));
+
+        assertEquals(config, configEqual);
+        assertNotEquals(config, configNotEqual);
+    }
+
+    @Test
     public void testParceling() {
         final VcnConfig config = buildTestConfig(mContext);
 
diff --git a/tests/vcn/java/com/android/server/VcnManagementServiceTest.java b/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
index 258642ac..075bc5e 100644
--- a/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
+++ b/tests/vcn/java/com/android/server/VcnManagementServiceTest.java
@@ -258,7 +258,7 @@
 
         doReturn(Collections.singleton(TRANSPORT_WIFI))
                 .when(mMockDeps)
-                .getRestrictedTransports(any(), any());
+                .getRestrictedTransports(any(), any(), any());
     }
 
 
@@ -1038,18 +1038,18 @@
                 new LinkProperties());
     }
 
-    private void checkGetRestrictedTransports(
+    private void checkGetRestrictedTransportsFromCarrierConfig(
             ParcelUuid subGrp,
             TelephonySubscriptionSnapshot lastSnapshot,
             Set<Integer> expectedTransports) {
         Set<Integer> result =
                 new VcnManagementService.Dependencies()
-                        .getRestrictedTransports(subGrp, lastSnapshot);
+                        .getRestrictedTransportsFromCarrierConfig(subGrp, lastSnapshot);
         assertEquals(expectedTransports, result);
     }
 
     @Test
-    public void testGetRestrictedTransports() {
+    public void testGetRestrictedTransportsFromCarrierConfig() {
         final Set<Integer> restrictedTransports = new ArraySet<>();
         restrictedTransports.add(TRANSPORT_CELLULAR);
         restrictedTransports.add(TRANSPORT_WIFI);
@@ -1065,11 +1065,12 @@
                 mock(TelephonySubscriptionSnapshot.class);
         doReturn(carrierConfig).when(lastSnapshot).getCarrierConfigForSubGrp(eq(TEST_UUID_2));
 
-        checkGetRestrictedTransports(TEST_UUID_2, lastSnapshot, restrictedTransports);
+        checkGetRestrictedTransportsFromCarrierConfig(
+                TEST_UUID_2, lastSnapshot, restrictedTransports);
     }
 
     @Test
-    public void testGetRestrictedTransports_noRestrictPolicyConfigured() {
+    public void testGetRestrictedTransportsFromCarrierConfig_noRestrictPolicyConfigured() {
         final Set<Integer> restrictedTransports = Collections.singleton(TRANSPORT_WIFI);
 
         final PersistableBundleWrapper carrierConfig =
@@ -1078,17 +1079,54 @@
                 mock(TelephonySubscriptionSnapshot.class);
         doReturn(carrierConfig).when(lastSnapshot).getCarrierConfigForSubGrp(eq(TEST_UUID_2));
 
-        checkGetRestrictedTransports(TEST_UUID_2, lastSnapshot, restrictedTransports);
+        checkGetRestrictedTransportsFromCarrierConfig(
+                TEST_UUID_2, lastSnapshot, restrictedTransports);
     }
 
     @Test
-    public void testGetRestrictedTransports_noCarrierConfig() {
+    public void testGetRestrictedTransportsFromCarrierConfig_noCarrierConfig() {
         final Set<Integer> restrictedTransports = Collections.singleton(TRANSPORT_WIFI);
 
         final TelephonySubscriptionSnapshot lastSnapshot =
                 mock(TelephonySubscriptionSnapshot.class);
 
-        checkGetRestrictedTransports(TEST_UUID_2, lastSnapshot, restrictedTransports);
+        checkGetRestrictedTransportsFromCarrierConfig(
+                TEST_UUID_2, lastSnapshot, restrictedTransports);
+    }
+
+    @Test
+    public void testGetRestrictedTransportsFromCarrierConfigAndVcnConfig() {
+        // Configure restricted transport in CarrierConfig
+        final Set<Integer> restrictedTransportInCarrierConfig =
+                Collections.singleton(TRANSPORT_WIFI);
+
+        PersistableBundle carrierConfigBundle = new PersistableBundle();
+        carrierConfigBundle.putIntArray(
+                VCN_RESTRICTED_TRANSPORTS_INT_ARRAY_KEY,
+                restrictedTransportInCarrierConfig.stream().mapToInt(i -> i).toArray());
+        final PersistableBundleWrapper carrierConfig =
+                new PersistableBundleWrapper(carrierConfigBundle);
+
+        final TelephonySubscriptionSnapshot lastSnapshot =
+                mock(TelephonySubscriptionSnapshot.class);
+        doReturn(carrierConfig).when(lastSnapshot).getCarrierConfigForSubGrp(eq(TEST_UUID_2));
+
+        // Configure restricted transport in VcnConfig
+        final Context mockContext = mock(Context.class);
+        doReturn(TEST_PACKAGE_NAME).when(mockContext).getOpPackageName();
+        final VcnConfig vcnConfig =
+                VcnConfigTest.buildTestConfig(
+                        mockContext, Collections.singleton(TRANSPORT_CELLULAR));
+
+        // Verifications
+        final Set<Integer> expectedTransports = new ArraySet<>();
+        expectedTransports.add(TRANSPORT_CELLULAR);
+        expectedTransports.add(TRANSPORT_WIFI);
+
+        Set<Integer> result =
+                new VcnManagementService.Dependencies()
+                        .getRestrictedTransports(TEST_UUID_2, lastSnapshot, vcnConfig);
+        assertEquals(expectedTransports, result);
     }
 
     private void checkGetUnderlyingNetworkPolicy(
@@ -1103,7 +1141,7 @@
         if (isTransportRestricted) {
             restrictedTransports.add(transportType);
         }
-        doReturn(restrictedTransports).when(mMockDeps).getRestrictedTransports(any(), any());
+        doReturn(restrictedTransports).when(mMockDeps).getRestrictedTransports(any(), any(), any());
 
         final VcnUnderlyingNetworkPolicy policy =
                 startVcnAndGetPolicyForTransport(
@@ -1201,7 +1239,7 @@
     public void testGetUnderlyingNetworkPolicyCell_restrictWifi() throws Exception {
         doReturn(Collections.singleton(TRANSPORT_WIFI))
                 .when(mMockDeps)
-                .getRestrictedTransports(any(), any());
+                .getRestrictedTransports(any(), any(), any());
 
         setupSubscriptionAndStartVcn(TEST_SUBSCRIPTION_ID, TEST_UUID_2, true /* isVcnActive */);
 
@@ -1344,6 +1382,23 @@
     }
 
     @Test
+    public void testVcnConfigChangeUpdatesPolicyListener() throws Exception {
+        setupActiveSubscription(TEST_UUID_2);
+
+        mVcnMgmtSvc.setVcnConfig(TEST_UUID_2, TEST_VCN_CONFIG, TEST_PACKAGE_NAME);
+        mVcnMgmtSvc.addVcnUnderlyingNetworkPolicyListener(mMockPolicyListener);
+
+        final Context mockContext = mock(Context.class);
+        doReturn(TEST_PACKAGE_NAME).when(mockContext).getOpPackageName();
+        final VcnConfig vcnConfig =
+                VcnConfigTest.buildTestConfig(
+                        mockContext, Collections.singleton(TRANSPORT_CELLULAR));
+        mVcnMgmtSvc.setVcnConfig(TEST_UUID_2, vcnConfig, TEST_PACKAGE_NAME);
+
+        verify(mMockPolicyListener).onPolicyChanged();
+    }
+
+    @Test
     public void testRemoveVcnUpdatesPolicyListener() throws Exception {
         setupActiveSubscription(TEST_UUID_2);
 
@@ -1375,7 +1430,7 @@
         setupActiveSubscription(TEST_UUID_2);
 
         mVcnMgmtSvc.setVcnConfig(TEST_UUID_2, TEST_VCN_CONFIG, TEST_PACKAGE_NAME);
-        mVcnMgmtSvc.addVcnUnderlyingNetworkPolicyListenerForTest(mMockPolicyListener);
+        mVcnMgmtSvc.addVcnUnderlyingNetworkPolicyListener(mMockPolicyListener);
 
         final TelephonySubscriptionSnapshot snapshot =
                 buildSubscriptionSnapshot(