diff options
8 files changed, 252 insertions, 41 deletions
diff --git a/services/core/java/com/android/server/VcnManagementService.java b/services/core/java/com/android/server/VcnManagementService.java index 916bec27af39..4dce59f23a79 100644 --- a/services/core/java/com/android/server/VcnManagementService.java +++ b/services/core/java/com/android/server/VcnManagementService.java @@ -66,6 +66,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.concurrent.TimeUnit; /** @@ -291,8 +292,9 @@ public class VcnManagementService extends IVcnManagementService.Stub { @NonNull VcnContext vcnContext, @NonNull ParcelUuid subscriptionGroup, @NonNull VcnConfig config, - @NonNull TelephonySubscriptionSnapshot snapshot) { - return new Vcn(vcnContext, subscriptionGroup, config, snapshot); + @NonNull TelephonySubscriptionSnapshot snapshot, + @NonNull VcnSafemodeCallback safemodeCallback) { + return new Vcn(vcnContext, subscriptionGroup, config, snapshot, safemodeCallback); } /** Gets the subId indicated by the given {@link WifiInfo}. */ @@ -438,7 +440,12 @@ public class VcnManagementService extends IVcnManagementService.Stub { // TODO(b/176939047): Support multiple VCNs active at the same time, or limit to one active // VCN. - final Vcn newInstance = mDeps.newVcn(mVcnContext, subscriptionGroup, config, mLastSnapshot); + final VcnSafemodeCallbackImpl safemodeCallback = + new VcnSafemodeCallbackImpl(subscriptionGroup); + + final Vcn newInstance = + mDeps.newVcn( + mVcnContext, subscriptionGroup, config, mLastSnapshot, safemodeCallback); mVcns.put(subscriptionGroup, newInstance); // Now that a new VCN has started, notify all registered listeners to refresh their @@ -536,7 +543,7 @@ public class VcnManagementService extends IVcnManagementService.Stub { } } - /** Get current configuration list for testing purposes */ + /** Get current VCNs for testing purposes */ @VisibleForTesting(visibility = Visibility.PRIVATE) public Map<ParcelUuid, Vcn> getAllVcns() { synchronized (mLock) { @@ -638,8 +645,8 @@ public class VcnManagementService extends IVcnManagementService.Stub { synchronized (mLock) { ParcelUuid subGroup = mLastSnapshot.getGroupForSubId(subId); - // TODO(b/178140910): only mark the Network as VCN-managed if not in safe mode - if (mVcns.containsKey(subGroup)) { + Vcn vcn = mVcns.get(subGroup); + if (vcn != null && vcn.isActive()) { isVcnManagedNetwork = true; } } @@ -651,4 +658,31 @@ public class VcnManagementService extends IVcnManagementService.Stub { return new VcnUnderlyingNetworkPolicy(false /* isTearDownRequested */, networkCapabilities); } + + /** Callback for signalling when a Vcn has entered Safemode. */ + public interface VcnSafemodeCallback { + /** Called by a Vcn to signal that it has entered Safemode. */ + void onEnteredSafemode(); + } + + /** VcnSafemodeCallback is used by Vcns to notify VcnManagementService on entering Safemode. */ + private class VcnSafemodeCallbackImpl implements VcnSafemodeCallback { + @NonNull private final ParcelUuid mSubGroup; + + private VcnSafemodeCallbackImpl(@NonNull final ParcelUuid subGroup) { + mSubGroup = Objects.requireNonNull(subGroup, "Missing subGroup"); + } + + @Override + public void onEnteredSafemode() { + synchronized (mLock) { + // Ignore if this subscription group doesn't exist anymore + if (!mVcns.containsKey(mSubGroup)) { + return; + } + + notifyAllPolicyListenersLocked(); + } + } + } } diff --git a/services/core/java/com/android/server/vcn/Vcn.java b/services/core/java/com/android/server/vcn/Vcn.java index a82f239948ff..5ec527a7d6c4 100644 --- a/services/core/java/com/android/server/vcn/Vcn.java +++ b/services/core/java/com/android/server/vcn/Vcn.java @@ -29,6 +29,7 @@ import android.util.Slog; import com.android.internal.annotations.VisibleForTesting; import com.android.internal.annotations.VisibleForTesting.Visibility; +import com.android.server.VcnManagementService.VcnSafemodeCallback; import com.android.server.vcn.TelephonySubscriptionTracker.TelephonySubscriptionSnapshot; import java.util.Collections; @@ -37,6 +38,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; /** * Represents an single instance of a VCN. @@ -82,10 +84,19 @@ public class Vcn extends Handler { /** Triggers an immediate teardown of the entire Vcn, including GatewayConnections. */ private static final int MSG_CMD_TEARDOWN = MSG_CMD_BASE; + /** + * Causes this VCN to immediately enter Safemode. + * + * <p>Upon entering Safemode, the VCN will unregister its RequestListener, tear down all of its + * VcnGatewayConnections, and notify VcnManagementService that it is in Safemode. + */ + private static final int MSG_CMD_ENTER_SAFEMODE = MSG_CMD_BASE + 1; + @NonNull private final VcnContext mVcnContext; @NonNull private final ParcelUuid mSubscriptionGroup; @NonNull private final Dependencies mDeps; @NonNull private final VcnNetworkRequestListener mRequestListener; + @NonNull private final VcnSafemodeCallback mVcnSafemodeCallback; @NonNull private final Map<VcnGatewayConnectionConfig, VcnGatewayConnection> mVcnGatewayConnections = @@ -94,14 +105,33 @@ public class Vcn extends Handler { @NonNull private VcnConfig mConfig; @NonNull private TelephonySubscriptionSnapshot mLastSnapshot; - private boolean mIsRunning = true; + /** + * Whether this Vcn instance is active and running. + * + * <p>The value will be {@code true} while running. It will be {@code false} if the VCN has been + * shut down or has entered safe mode. + * + * <p>This AtomicBoolean is required in order to ensure consistency and correctness across + * multiple threads. Unlike the rest of the Vcn, this is queried synchronously on Binder threads + * from VcnManagementService, and therefore cannot rely on guarantees of running on the VCN + * Looper. + */ + // TODO(b/179429339): update when exiting safemode (when a new VcnConfig is provided) + private final AtomicBoolean mIsActive = new AtomicBoolean(true); public Vcn( @NonNull VcnContext vcnContext, @NonNull ParcelUuid subscriptionGroup, @NonNull VcnConfig config, - @NonNull TelephonySubscriptionSnapshot snapshot) { - this(vcnContext, subscriptionGroup, config, snapshot, new Dependencies()); + @NonNull TelephonySubscriptionSnapshot snapshot, + @NonNull VcnSafemodeCallback vcnSafemodeCallback) { + this( + vcnContext, + subscriptionGroup, + config, + snapshot, + vcnSafemodeCallback, + new Dependencies()); } @VisibleForTesting(visibility = Visibility.PRIVATE) @@ -110,10 +140,13 @@ public class Vcn extends Handler { @NonNull ParcelUuid subscriptionGroup, @NonNull VcnConfig config, @NonNull TelephonySubscriptionSnapshot snapshot, + @NonNull VcnSafemodeCallback vcnSafemodeCallback, @NonNull Dependencies deps) { super(Objects.requireNonNull(vcnContext, "Missing vcnContext").getLooper()); mVcnContext = vcnContext; mSubscriptionGroup = Objects.requireNonNull(subscriptionGroup, "Missing subscriptionGroup"); + mVcnSafemodeCallback = + Objects.requireNonNull(vcnSafemodeCallback, "Missing vcnSafemodeCallback"); mDeps = Objects.requireNonNull(deps, "Missing deps"); mRequestListener = new VcnNetworkRequestListener(); @@ -143,6 +176,11 @@ public class Vcn extends Handler { sendMessageAtFrontOfQueue(obtainMessage(MSG_CMD_TEARDOWN)); } + /** Synchronously checks whether this Vcn is active. */ + public boolean isActive() { + return mIsActive.get(); + } + /** Get current Gateways for testing purposes */ @VisibleForTesting(visibility = Visibility.PRIVATE) public Set<VcnGatewayConnection> getVcnGatewayConnections() { @@ -160,7 +198,7 @@ public class Vcn extends Handler { @Override public void handleMessage(@NonNull Message msg) { - if (!mIsRunning) { + if (!isActive()) { return; } @@ -177,6 +215,9 @@ public class Vcn extends Handler { case MSG_CMD_TEARDOWN: handleTeardown(); break; + case MSG_CMD_ENTER_SAFEMODE: + handleEnterSafemode(); + break; default: Slog.wtf(getLogTag(), "Unknown msg.what: " + msg.what); } @@ -198,7 +239,13 @@ public class Vcn extends Handler { gatewayConnection.teardownAsynchronously(); } - mIsRunning = false; + mIsActive.set(false); + } + + private void handleEnterSafemode() { + handleTeardown(); + + mVcnSafemodeCallback.onEnteredSafemode(); } private void handleNetworkRequested( @@ -233,7 +280,8 @@ public class Vcn extends Handler { mVcnContext, mSubscriptionGroup, mLastSnapshot, - gatewayConnectionConfig); + gatewayConnectionConfig, + new VcnGatewayStatusCallbackImpl()); mVcnGatewayConnections.put(gatewayConnectionConfig, vcnGatewayConnection); } } @@ -242,7 +290,7 @@ public class Vcn extends Handler { private void handleSubscriptionsChanged(@NonNull TelephonySubscriptionSnapshot snapshot) { mLastSnapshot = snapshot; - if (mIsRunning) { + if (isActive()) { for (VcnGatewayConnection gatewayConnection : mVcnGatewayConnections.values()) { gatewayConnection.updateSubscriptionSnapshot(mLastSnapshot); } @@ -271,6 +319,20 @@ public class Vcn extends Handler { return 52; } + /** Callback used for passing status signals from a VcnGatewayConnection to its managing Vcn. */ + @VisibleForTesting(visibility = Visibility.PACKAGE) + public interface VcnGatewayStatusCallback { + /** Called by a VcnGatewayConnection to indicate that it has entered Safemode. */ + void onEnteredSafemode(); + } + + private class VcnGatewayStatusCallbackImpl implements VcnGatewayStatusCallback { + @Override + public void onEnteredSafemode() { + sendMessage(obtainMessage(MSG_CMD_ENTER_SAFEMODE)); + } + } + /** External dependencies used by Vcn, for injection in tests */ @VisibleForTesting(visibility = Visibility.PRIVATE) public static class Dependencies { @@ -279,9 +341,14 @@ public class Vcn extends Handler { VcnContext vcnContext, ParcelUuid subscriptionGroup, TelephonySubscriptionSnapshot snapshot, - VcnGatewayConnectionConfig connectionConfig) { + VcnGatewayConnectionConfig connectionConfig, + VcnGatewayStatusCallback gatewayStatusCallback) { return new VcnGatewayConnection( - vcnContext, subscriptionGroup, snapshot, connectionConfig); + vcnContext, + subscriptionGroup, + snapshot, + connectionConfig, + gatewayStatusCallback); } } } diff --git a/services/core/java/com/android/server/vcn/VcnGatewayConnection.java b/services/core/java/com/android/server/vcn/VcnGatewayConnection.java index 853bb4324f90..bd82fcc1447b 100644 --- a/services/core/java/com/android/server/vcn/VcnGatewayConnection.java +++ b/services/core/java/com/android/server/vcn/VcnGatewayConnection.java @@ -69,6 +69,7 @@ import com.android.internal.util.StateMachine; import com.android.server.vcn.TelephonySubscriptionTracker.TelephonySubscriptionSnapshot; import com.android.server.vcn.UnderlyingNetworkTracker.UnderlyingNetworkRecord; import com.android.server.vcn.UnderlyingNetworkTracker.UnderlyingNetworkTrackerCallback; +import com.android.server.vcn.Vcn.VcnGatewayStatusCallback; import java.io.IOException; import java.net.Inet4Address; @@ -412,6 +413,7 @@ public class VcnGatewayConnection extends StateMachine { @NonNull private final ParcelUuid mSubscriptionGroup; @NonNull private final UnderlyingNetworkTracker mUnderlyingNetworkTracker; @NonNull private final VcnGatewayConnectionConfig mConnectionConfig; + @NonNull private final VcnGatewayStatusCallback mGatewayStatusCallback; @NonNull private final Dependencies mDeps; @NonNull private final VcnUnderlyingNetworkTrackerCallback mUnderlyingNetworkTrackerCallback; @@ -487,8 +489,15 @@ public class VcnGatewayConnection extends StateMachine { @NonNull VcnContext vcnContext, @NonNull ParcelUuid subscriptionGroup, @NonNull TelephonySubscriptionSnapshot snapshot, - @NonNull VcnGatewayConnectionConfig connectionConfig) { - this(vcnContext, subscriptionGroup, snapshot, connectionConfig, new Dependencies()); + @NonNull VcnGatewayConnectionConfig connectionConfig, + @NonNull VcnGatewayStatusCallback gatewayStatusCallback) { + this( + vcnContext, + subscriptionGroup, + snapshot, + connectionConfig, + gatewayStatusCallback, + new Dependencies()); } @VisibleForTesting(visibility = Visibility.PRIVATE) @@ -497,11 +506,14 @@ public class VcnGatewayConnection extends StateMachine { @NonNull ParcelUuid subscriptionGroup, @NonNull TelephonySubscriptionSnapshot snapshot, @NonNull VcnGatewayConnectionConfig connectionConfig, + @NonNull VcnGatewayStatusCallback gatewayStatusCallback, @NonNull Dependencies deps) { super(TAG, Objects.requireNonNull(vcnContext, "Missing vcnContext").getLooper()); mVcnContext = vcnContext; mSubscriptionGroup = Objects.requireNonNull(subscriptionGroup, "Missing subscriptionGroup"); mConnectionConfig = Objects.requireNonNull(connectionConfig, "Missing connectionConfig"); + mGatewayStatusCallback = + Objects.requireNonNull(gatewayStatusCallback, "Missing gatewayStatusCallback"); mDeps = Objects.requireNonNull(deps, "Missing deps"); synchronized (mLock) { diff --git a/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java b/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java index 86a15912b6b4..3e659d0bc128 100644 --- a/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java +++ b/tests/vcn/java/android/net/vcn/VcnGatewayConnectionConfigTest.java @@ -59,12 +59,17 @@ public class VcnGatewayConnectionConfigTest { // Public for use in VcnGatewayConnectionTest public static VcnGatewayConnectionConfig buildTestConfig() { + return buildTestConfigWithExposedCaps(EXPOSED_CAPS); + } + + // Public for use in VcnGatewayConnectionTest + public static VcnGatewayConnectionConfig buildTestConfigWithExposedCaps(int... exposedCaps) { final VcnGatewayConnectionConfig.Builder builder = new VcnGatewayConnectionConfig.Builder() .setRetryInterval(RETRY_INTERVALS_MS) .setMaxMtu(MAX_MTU); - for (int caps : EXPOSED_CAPS) { + for (int caps : exposedCaps) { builder.addExposedCapability(caps); } diff --git a/tests/vcn/java/com/android/server/VcnManagementServiceTest.java b/tests/vcn/java/com/android/server/VcnManagementServiceTest.java index e32e1e831f83..485964487fda 100644 --- a/tests/vcn/java/com/android/server/VcnManagementServiceTest.java +++ b/tests/vcn/java/com/android/server/VcnManagementServiceTest.java @@ -66,6 +66,7 @@ import android.telephony.TelephonyManager; import androidx.test.filters.SmallTest; import androidx.test.runner.AndroidJUnit4; +import com.android.server.VcnManagementService.VcnSafemodeCallback; import com.android.server.vcn.TelephonySubscriptionTracker; import com.android.server.vcn.Vcn; import com.android.server.vcn.VcnContext; @@ -142,6 +143,9 @@ public class VcnManagementServiceTest { private final TelephonySubscriptionTracker mSubscriptionTracker = mock(TelephonySubscriptionTracker.class); + private final ArgumentCaptor<VcnSafemodeCallback> mSafemodeCallbackCaptor = + ArgumentCaptor.forClass(VcnSafemodeCallback.class); + private final VcnManagementService mVcnMgmtSvc; private final IVcnUnderlyingNetworkPolicyListener mMockPolicyListener = @@ -184,7 +188,7 @@ public class VcnManagementServiceTest { doAnswer((invocation) -> { // Mock-within a doAnswer is safe, because it doesn't actually run nested. return mock(Vcn.class); - }).when(mMockDeps).newVcn(any(), any(), any(), any()); + }).when(mMockDeps).newVcn(any(), any(), any(), any(), any()); final PersistableBundle bundle = PersistableBundleUtils.fromMap( @@ -307,7 +311,7 @@ public class VcnManagementServiceTest { TelephonySubscriptionSnapshot snapshot = triggerSubscriptionTrackerCbAndGetSnapshot(Collections.singleton(TEST_UUID_1)); verify(mMockDeps) - .newVcn(eq(mVcnContext), eq(TEST_UUID_1), eq(TEST_VCN_CONFIG), eq(snapshot)); + .newVcn(eq(mVcnContext), eq(TEST_UUID_1), eq(TEST_VCN_CONFIG), eq(snapshot), any()); } @Test @@ -485,7 +489,8 @@ public class VcnManagementServiceTest { eq(mVcnContext), eq(TEST_UUID_2), eq(TEST_VCN_CONFIG), - eq(TelephonySubscriptionSnapshot.EMPTY_SNAPSHOT)); + eq(TelephonySubscriptionSnapshot.EMPTY_SNAPSHOT), + any()); // Verify Vcn is updated if it was previously started mVcnMgmtSvc.setVcnConfig(TEST_UUID_2, TEST_VCN_CONFIG, TEST_PACKAGE_NAME); @@ -634,4 +639,25 @@ public class VcnManagementServiceTest { verify(mMockPolicyListener).onPolicyChanged(); } + + @Test + public void testVcnSafemodeCallbackOnEnteredSafemode() throws Exception { + TelephonySubscriptionSnapshot snapshot = + triggerSubscriptionTrackerCbAndGetSnapshot(Collections.singleton(TEST_UUID_1)); + verify(mMockDeps) + .newVcn( + eq(mVcnContext), + eq(TEST_UUID_1), + eq(TEST_VCN_CONFIG), + eq(snapshot), + mSafemodeCallbackCaptor.capture()); + + mVcnMgmtSvc.addVcnUnderlyingNetworkPolicyListener(mMockPolicyListener); + + VcnSafemodeCallback safemodeCallback = mSafemodeCallbackCaptor.getValue(); + safemodeCallback.onEnteredSafemode(); + + assertFalse(mVcnMgmtSvc.getAllVcns().get(TEST_UUID_1).isActive()); + verify(mMockPolicyListener).onPolicyChanged(); + } } diff --git a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionDisconnectedStateTest.java b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionDisconnectedStateTest.java index fbaae6f534a9..8643d8a2ea8a 100644 --- a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionDisconnectedStateTest.java +++ b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionDisconnectedStateTest.java @@ -45,7 +45,12 @@ public class VcnGatewayConnectionDisconnectedStateTest extends VcnGatewayConnect public void testEnterWhileNotRunningTriggersQuit() throws Exception { final VcnGatewayConnection vgc = new VcnGatewayConnection( - mVcnContext, TEST_SUB_GRP, TEST_SUBSCRIPTION_SNAPSHOT, mConfig, mDeps); + mVcnContext, + TEST_SUB_GRP, + TEST_SUBSCRIPTION_SNAPSHOT, + mConfig, + mGatewayStatusCallback, + mDeps); vgc.setIsRunning(false); vgc.transitionTo(vgc.mDisconnectedState); diff --git a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java index df1341cce20f..333b5b990dde 100644 --- a/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java +++ b/tests/vcn/java/com/android/server/vcn/VcnGatewayConnectionTestBase.java @@ -43,6 +43,7 @@ import android.os.test.TestLooper; import com.android.server.IpSecService; import com.android.server.vcn.TelephonySubscriptionTracker.TelephonySubscriptionSnapshot; +import com.android.server.vcn.Vcn.VcnGatewayStatusCallback; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -80,6 +81,7 @@ public class VcnGatewayConnectionTestBase { @NonNull protected final VcnNetworkProvider mVcnNetworkProvider; @NonNull protected final VcnContext mVcnContext; @NonNull protected final VcnGatewayConnectionConfig mConfig; + @NonNull protected final VcnGatewayStatusCallback mGatewayStatusCallback; @NonNull protected final VcnGatewayConnection.Dependencies mDeps; @NonNull protected final UnderlyingNetworkTracker mUnderlyingNetworkTracker; @@ -94,6 +96,7 @@ public class VcnGatewayConnectionTestBase { mVcnNetworkProvider = mock(VcnNetworkProvider.class); mVcnContext = mock(VcnContext.class); mConfig = VcnGatewayConnectionConfigTest.buildTestConfig(); + mGatewayStatusCallback = mock(VcnGatewayStatusCallback.class); mDeps = mock(VcnGatewayConnection.Dependencies.class); mUnderlyingNetworkTracker = mock(UnderlyingNetworkTracker.class); @@ -123,7 +126,12 @@ public class VcnGatewayConnectionTestBase { mGatewayConnection = new VcnGatewayConnection( - mVcnContext, TEST_SUB_GRP, TEST_SUBSCRIPTION_SNAPSHOT, mConfig, mDeps); + mVcnContext, + TEST_SUB_GRP, + TEST_SUBSCRIPTION_SNAPSHOT, + mConfig, + mGatewayStatusCallback, + mDeps); } protected IpSecTransform makeDummyIpSecTransform() throws Exception { diff --git a/tests/vcn/java/com/android/server/vcn/VcnTest.java b/tests/vcn/java/com/android/server/vcn/VcnTest.java index 0c1df763a08e..66cbf84619ab 100644 --- a/tests/vcn/java/com/android/server/vcn/VcnTest.java +++ b/tests/vcn/java/com/android/server/vcn/VcnTest.java @@ -16,22 +16,27 @@ package com.android.server.vcn; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import android.content.Context; import android.net.NetworkRequest; import android.net.vcn.VcnConfig; +import android.net.vcn.VcnGatewayConnectionConfig; import android.net.vcn.VcnGatewayConnectionConfigTest; import android.os.ParcelUuid; import android.os.test.TestLooper; +import com.android.server.VcnManagementService.VcnSafemodeCallback; import com.android.server.vcn.TelephonySubscriptionTracker.TelephonySubscriptionSnapshot; +import com.android.server.vcn.Vcn.VcnGatewayStatusCallback; import com.android.server.vcn.VcnNetworkProvider.NetworkRequestListener; import org.junit.Before; @@ -51,9 +56,13 @@ public class VcnTest { private VcnContext mVcnContext; private TelephonySubscriptionSnapshot mSubscriptionSnapshot; private VcnNetworkProvider mVcnNetworkProvider; + private VcnSafemodeCallback mVcnSafemodeCallback; private Vcn.Dependencies mDeps; + private ArgumentCaptor<VcnGatewayStatusCallback> mGatewayStatusCallbackCaptor; + private TestLooper mTestLooper; + private VcnGatewayConnectionConfig mGatewayConnectionConfig; private VcnConfig mConfig; private Vcn mVcn; @@ -63,6 +72,7 @@ public class VcnTest { mVcnContext = mock(VcnContext.class); mSubscriptionSnapshot = mock(TelephonySubscriptionSnapshot.class); mVcnNetworkProvider = mock(VcnNetworkProvider.class); + mVcnSafemodeCallback = mock(VcnSafemodeCallback.class); mDeps = mock(Vcn.Dependencies.class); mTestLooper = new TestLooper(); @@ -76,15 +86,26 @@ public class VcnTest { doAnswer((invocation) -> { // Mock-within a doAnswer is safe, because it doesn't actually run nested. return mock(VcnGatewayConnection.class); - }).when(mDeps).newVcnGatewayConnection(any(), any(), any(), any()); + }).when(mDeps).newVcnGatewayConnection(any(), any(), any(), any(), any()); - mConfig = - new VcnConfig.Builder(mContext) - .addGatewayConnectionConfig( - VcnGatewayConnectionConfigTest.buildTestConfig()) - .build(); + mGatewayStatusCallbackCaptor = ArgumentCaptor.forClass(VcnGatewayStatusCallback.class); - mVcn = new Vcn(mVcnContext, TEST_SUB_GROUP, mConfig, mSubscriptionSnapshot, mDeps); + final VcnConfig.Builder configBuilder = new VcnConfig.Builder(mContext); + for (final int capability : VcnGatewayConnectionConfigTest.EXPOSED_CAPS) { + configBuilder.addGatewayConnectionConfig( + VcnGatewayConnectionConfigTest.buildTestConfigWithExposedCaps(capability)); + } + configBuilder.addGatewayConnectionConfig(VcnGatewayConnectionConfigTest.buildTestConfig()); + mConfig = configBuilder.build(); + + mVcn = + new Vcn( + mVcnContext, + TEST_SUB_GROUP, + mConfig, + mSubscriptionSnapshot, + mVcnSafemodeCallback, + mDeps); } private NetworkRequestListener verifyAndGetRequestListener() { @@ -95,23 +116,22 @@ public class VcnTest { return mNetworkRequestListenerCaptor.getValue(); } - private NetworkRequest getNetworkRequestWithCapabilities(int[] networkCapabilities) { - final NetworkRequest.Builder builder = new NetworkRequest.Builder(); - for (final int netCapability : networkCapabilities) { - builder.addCapability(netCapability); + private void startVcnGatewayWithCapabilities( + NetworkRequestListener requestListener, int... netCapabilities) { + final NetworkRequest.Builder requestBuilder = new NetworkRequest.Builder(); + for (final int netCapability : netCapabilities) { + requestBuilder.addCapability(netCapability); } - return builder.build(); + + requestListener.onNetworkRequested(requestBuilder.build(), NETWORK_SCORE, PROVIDER_ID); + mTestLooper.dispatchAll(); } @Test public void testSubscriptionSnapshotUpdatesVcnGatewayConnections() { final NetworkRequestListener requestListener = verifyAndGetRequestListener(); - - requestListener.onNetworkRequested( - getNetworkRequestWithCapabilities(VcnGatewayConnectionConfigTest.EXPOSED_CAPS), - NETWORK_SCORE, - PROVIDER_ID); - mTestLooper.dispatchAll(); + startVcnGatewayWithCapabilities( + requestListener, VcnGatewayConnectionConfigTest.EXPOSED_CAPS); final Set<VcnGatewayConnection> gatewayConnections = mVcn.getVcnGatewayConnections(); assertFalse(gatewayConnections.isEmpty()); @@ -126,4 +146,38 @@ public class VcnTest { verify(gateway).updateSubscriptionSnapshot(eq(updatedSnapshot)); } } + + @Test + public void testGatewayEnteringSafemodeNotifiesVcn() { + final NetworkRequestListener requestListener = verifyAndGetRequestListener(); + for (final int capability : VcnGatewayConnectionConfigTest.EXPOSED_CAPS) { + startVcnGatewayWithCapabilities(requestListener, capability); + } + + // Each Capability in EXPOSED_CAPS was split into a separate VcnGatewayConnection in #setUp. + // Expect one VcnGatewayConnection per capability. + final int numExpectedGateways = VcnGatewayConnectionConfigTest.EXPOSED_CAPS.length; + + final Set<VcnGatewayConnection> gatewayConnections = mVcn.getVcnGatewayConnections(); + assertEquals(numExpectedGateways, gatewayConnections.size()); + verify(mDeps, times(numExpectedGateways)) + .newVcnGatewayConnection( + eq(mVcnContext), + eq(TEST_SUB_GROUP), + eq(mSubscriptionSnapshot), + any(), + mGatewayStatusCallbackCaptor.capture()); + + // Doesn't matter which callback this gets - any Gateway entering Safemode should shut down + // all Gateways + final VcnGatewayStatusCallback statusCallback = mGatewayStatusCallbackCaptor.getValue(); + statusCallback.onEnteredSafemode(); + mTestLooper.dispatchAll(); + + for (final VcnGatewayConnection gatewayConnection : gatewayConnections) { + verify(gatewayConnection).teardownAsynchronously(); + } + verify(mVcnNetworkProvider).unregisterListener(requestListener); + verify(mVcnSafemodeCallback).onEnteredSafemode(); + } } |