Merge "[Test] multicast forwarding from LLA/MLA address test" into main
diff --git a/common/FlaggedApi.bp b/common/FlaggedApi.bp
index c382e76..449d7ae 100644
--- a/common/FlaggedApi.bp
+++ b/common/FlaggedApi.bp
@@ -21,3 +21,11 @@
     srcs: ["flags.aconfig"],
     visibility: ["//packages/modules/Connectivity:__subpackages__"],
 }
+
+aconfig_declarations {
+    name: "nearby_flags",
+    package: "com.android.nearby.flags",
+    container: "system",
+    srcs: ["nearby_flags.aconfig"],
+    visibility: ["//packages/modules/Connectivity:__subpackages__"],
+}
diff --git a/common/nearby_flags.aconfig b/common/nearby_flags.aconfig
new file mode 100644
index 0000000..b957d33
--- /dev/null
+++ b/common/nearby_flags.aconfig
@@ -0,0 +1,9 @@
+package: "com.android.nearby.flags"
+container: "system"
+
+flag {
+    name: "powered_off_finding"
+    namespace: "nearby"
+    description: "Controls whether the Powered Off Finding feature is enabled"
+    bug: "307898240"
+}
diff --git a/framework-t/Android.bp b/framework-t/Android.bp
index 9203a3e..e40b55c 100644
--- a/framework-t/Android.bp
+++ b/framework-t/Android.bp
@@ -197,6 +197,7 @@
     ],
     aconfig_declarations: [
         "com.android.net.flags-aconfig",
+        "nearby_flags",
     ],
 }
 
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index 1ea1815..915ec52 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -74,6 +74,7 @@
 import android.util.SparseIntArray;
 
 import com.android.internal.annotations.GuardedBy;
+import com.android.modules.utils.build.SdkLevel;
 
 import libcore.net.event.NetworkEventDispatcher;
 
@@ -6278,9 +6279,13 @@
     // Only the system server process and the network stack have access.
     @FlaggedApi(Flags.SUPPORT_IS_UID_NETWORKING_BLOCKED)
     @SystemApi(client = MODULE_LIBRARIES)
-    @RequiresApi(Build.VERSION_CODES.TIRAMISU)  // BPF maps were only mainlined in T
+    // Note b/326143935 kernel bug can trigger crash on some T device.
+    @RequiresApi(VERSION_CODES.UPSIDE_DOWN_CAKE)
     @RequiresPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK)
     public boolean isUidNetworkingBlocked(int uid, boolean isNetworkMetered) {
+        if (!SdkLevel.isAtLeastU()) {
+            Log.wtf(TAG, "isUidNetworkingBlocked is not supported on pre-U devices");
+        }
         final BpfNetMapsReader reader = BpfNetMapsReader.getInstance();
         // Note that before V, the data saver status in bpf is written by ConnectivityService
         // when receiving {@link #ACTION_RESTRICT_BACKGROUND_CHANGED}. Thus,
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index fe9bbba..56202fd 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -118,6 +118,14 @@
     }
 
     /**
+     * Indicates whether {@link #NSD_KNOWN_ANSWER_SUPPRESSION} is enabled, including for testing.
+     */
+    public boolean isKnownAnswerSuppressionEnabled() {
+        return mIsKnownAnswerSuppressionEnabled
+                || isForceEnabledForTest(NSD_KNOWN_ANSWER_SUPPRESSION);
+    }
+
+    /**
      * The constructor for {@link MdnsFeatureFlags}.
      */
     public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 96a59e2..ed0bde2 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -538,7 +538,7 @@
     }
 
     private boolean isTruncatedKnownAnswerPacket(MdnsPacket packet) {
-        if (!mMdnsFeatureFlags.mIsKnownAnswerSuppressionEnabled
+        if (!mMdnsFeatureFlags.isKnownAnswerSuppressionEnabled()
                 // Should ignore the response packet.
                 || (packet.flags & MdnsConstants.FLAGS_RESPONSE) != 0) {
             return false;
@@ -745,7 +745,7 @@
             // RR TTL as known by the Multicast DNS responder, the responder MUST
             // send an answer so as to update the querier's cache before the record
             // becomes in danger of expiration.
-            if (mMdnsFeatureFlags.mIsKnownAnswerSuppressionEnabled
+            if (mMdnsFeatureFlags.isKnownAnswerSuppressionEnabled()
                     && isKnownAnswer(info.record, knownAnswerRecords)) {
                 continue;
             }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java b/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java
index a46be3b..db3845a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java
@@ -145,7 +145,7 @@
     public void queueReply(@NonNull MdnsReplyInfo reply) {
         ensureRunningOnHandlerThread(mHandler);
 
-        if (mMdnsFeatureFlags.mIsKnownAnswerSuppressionEnabled) {
+        if (mMdnsFeatureFlags.isKnownAnswerSuppressionEnabled()) {
             mDependencies.removeMessages(mHandler, MSG_SEND, reply.source);
 
             final MdnsReplyInfo queuingReply = mSrcReplies.remove(reply.source);
@@ -231,7 +231,7 @@
         @Override
         public void handleMessage(@NonNull Message msg) {
             final MdnsReplyInfo replyInfo;
-            if (mMdnsFeatureFlags.mIsKnownAnswerSuppressionEnabled) {
+            if (mMdnsFeatureFlags.isKnownAnswerSuppressionEnabled()) {
                 // Retrieve the MdnsReplyInfo from the map via a source address, as the reply info
                 // will be combined or updated.
                 final InetSocketAddress source = (InetSocketAddress) msg.obj;
diff --git a/service-t/src/com/android/server/net/TrafficStatsRateLimitCache.java b/service-t/src/com/android/server/net/TrafficStatsRateLimitCache.java
new file mode 100644
index 0000000..8598ac4
--- /dev/null
+++ b/service-t/src/com/android/server/net/TrafficStatsRateLimitCache.java
@@ -0,0 +1,154 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.net;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.NetworkStats;
+
+import com.android.internal.annotations.GuardedBy;
+
+import java.time.Clock;
+import java.util.HashMap;
+import java.util.Objects;
+
+/**
+ * A thread-safe cache for storing and retrieving {@link NetworkStats.Entry} objects,
+ * keyed by (interface, uid), with a fixed expiry duration to manage data freshness.
+ *
+ * <p>An entry is considered fresh while {@code now - insertionTime <= expiryDurationMs},
+ * where both times come from the injected {@link Clock}. Expired entries are evicted
+ * lazily, when they are next looked up with {@link #get}.
+ */
+class TrafficStatsRateLimitCache {
+    private final Clock mClock;
+    private final long mExpiryDurationMs;
+
+    /**
+     * Constructs a new {@link TrafficStatsRateLimitCache} with the specified expiry duration.
+     *
+     * @param clock The {@link Clock} to use for determining timestamps.
+     * @param expiryDurationMs The expiry duration in milliseconds. Large values (up to
+     *                         {@link Long#MAX_VALUE}) are supported without overflow.
+     */
+    TrafficStatsRateLimitCache(@NonNull Clock clock, long expiryDurationMs) {
+        mClock = Objects.requireNonNull(clock);
+        mExpiryDurationMs = expiryDurationMs;
+    }
+
+    /** Key identifying a cached entry: an optional interface name plus a uid. */
+    private static class TrafficStatsCacheKey {
+        @Nullable
+        public final String iface;
+        public final int uid;
+
+        TrafficStatsCacheKey(@Nullable String iface, int uid) {
+            this.iface = iface;
+            this.uid = uid;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (!(o instanceof TrafficStatsCacheKey)) return false;
+            TrafficStatsCacheKey that = (TrafficStatsCacheKey) o;
+            return uid == that.uid && Objects.equals(iface, that.iface);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(iface, uid);
+        }
+    }
+
+    /** A cached entry together with the {@link Clock#millis()} time at which it was stored. */
+    private static class TrafficStatsCacheValue {
+        public final long timestamp;
+        @NonNull
+        public final NetworkStats.Entry entry;
+
+        TrafficStatsCacheValue(long timestamp, @NonNull NetworkStats.Entry entry) {
+            this.timestamp = timestamp;
+            this.entry = entry;
+        }
+    }
+
+    @GuardedBy("mMap")
+    private final HashMap<TrafficStatsCacheKey, TrafficStatsCacheValue> mMap = new HashMap<>();
+
+    /**
+     * Retrieves a {@link NetworkStats.Entry} from the cache, associated with the given key.
+     *
+     * <p>If the entry for the key has expired it is removed from the cache and null is returned.
+     *
+     * @param iface The interface name to include in the cache key. Null if not applicable.
+     * @param uid The UID to include in the cache key. {@code UID_ALL} if not applicable.
+     * @return The cached {@link NetworkStats.Entry}, or null if not found or expired.
+     */
+    @Nullable
+    NetworkStats.Entry get(@Nullable String iface, int uid) {
+        final TrafficStatsCacheKey key = new TrafficStatsCacheKey(iface, uid);
+        synchronized (mMap) { // Synchronize for thread-safety
+            final TrafficStatsCacheValue value = mMap.get(key);
+            if (value == null) {
+                return null;
+            }
+            if (isExpired(value.timestamp)) {
+                mMap.remove(key); // Lazily evict the stale entry.
+                return null;
+            }
+            return value.entry;
+        }
+    }
+
+    /**
+     * Stores a {@link NetworkStats.Entry} in the cache, associated with the given key.
+     * Any existing entry for the key is replaced and its expiry timer restarts.
+     *
+     * @param iface The interface name to include in the cache key. Null if not applicable.
+     * @param uid   The UID to include in the cache key. {@code UID_ALL} if not applicable.
+     * @param entry The {@link NetworkStats.Entry} to store in the cache.
+     * @throws NullPointerException if {@code entry} is null.
+     */
+    void put(@Nullable String iface, int uid, @NonNull final NetworkStats.Entry entry) {
+        Objects.requireNonNull(entry);
+        final TrafficStatsCacheKey key = new TrafficStatsCacheKey(iface, uid);
+        synchronized (mMap) { // Synchronize for thread-safety
+            mMap.put(key, new TrafficStatsCacheValue(mClock.millis(), entry));
+        }
+    }
+
+    /**
+     * Clear the cache.
+     */
+    void clear() {
+        synchronized (mMap) {
+            mMap.clear();
+        }
+    }
+
+    /**
+     * Returns whether an entry stored at {@code timestamp} is older than the expiry duration.
+     *
+     * <p>Compares the elapsed time rather than computing {@code timestamp + mExpiryDurationMs},
+     * which would overflow for very large expiry durations (e.g. {@link Long#MAX_VALUE}) and
+     * make every entry look expired.
+     */
+    private boolean isExpired(long timestamp) {
+        return mClock.millis() - timestamp > mExpiryDurationMs;
+    }
+}
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index e6287bc..6839c22 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -114,7 +114,6 @@
 import static com.android.net.module.util.PermissionUtils.enforceNetworkStackPermissionOr;
 import static com.android.net.module.util.PermissionUtils.hasAnyPermissionOf;
 import static com.android.server.ConnectivityStatsLog.CONNECTIVITY_STATE_SAMPLE;
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.server.connectivity.ConnectivityFlags.REQUEST_RESTRICTED_WIFI;
 
 import android.Manifest;
@@ -257,6 +256,7 @@
 import android.stats.connectivity.ValidatedState;
 import android.sysprop.NetworkProperties;
 import android.system.ErrnoException;
+import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.text.TextUtils;
 import android.util.ArrayMap;
@@ -377,6 +377,7 @@
 import java.util.TreeSet;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 
 /**
@@ -1287,18 +1288,14 @@
     }
     private final LegacyTypeTracker mLegacyTypeTracker = new LegacyTypeTracker(this);
 
-    private final CarrierPrivilegesLostListenerImpl mCarrierPrivilegesLostListenerImpl =
-            new CarrierPrivilegesLostListenerImpl();
-
-    private class CarrierPrivilegesLostListenerImpl implements CarrierPrivilegesLostListener {
-        @Override
-        public void onCarrierPrivilegesLost(int uid) {
-            if (mRequestRestrictedWifiEnabled) {
-                mHandler.sendMessage(mHandler.obtainMessage(
-                        EVENT_UID_CARRIER_PRIVILEGES_LOST, uid, 0 /* arg2 */));
-            }
+    @VisibleForTesting
+    void onCarrierPrivilegesLost(Integer uid, Integer subId) {
+        if (mRequestRestrictedWifiEnabled) {
+            mHandler.sendMessage(mHandler.obtainMessage(
+                    EVENT_UID_CARRIER_PRIVILEGES_LOST, uid, subId));
         }
     }
+
     final LocalPriorityDump mPriorityDumper = new LocalPriorityDump();
     /**
      * Helper class which parses out priority arguments and dumps sections according to their
@@ -1357,11 +1354,6 @@
         }
     }
 
-    @VisibleForTesting
-    CarrierPrivilegesLostListener getCarrierPrivilegesLostListener() {
-        return mCarrierPrivilegesLostListenerImpl;
-    }
-
     /**
      * Dependencies of ConnectivityService, for injection in tests.
      */
@@ -1525,7 +1517,7 @@
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 boolean requestRestrictedWifiEnabled,
-                @NonNull CarrierPrivilegesLostListener listener) {
+                @NonNull BiConsumer<Integer, Integer> listener) {
             if (isAtLeastT()) {
                 return new CarrierPrivilegeAuthenticator(
                         context, tm, requestRestrictedWifiEnabled, listener);
@@ -1813,7 +1805,7 @@
                 && mDeps.isFeatureEnabled(context, REQUEST_RESTRICTED_WIFI);
         mCarrierPrivilegeAuthenticator = mDeps.makeCarrierPrivilegeAuthenticator(
                 mContext, mTelephonyManager, mRequestRestrictedWifiEnabled,
-                mCarrierPrivilegesLostListenerImpl);
+                this::onCarrierPrivilegesLost);
 
         if (mDeps.isAtLeastU()
                 && mDeps
@@ -5401,6 +5393,13 @@
         return false;
     }
 
+    private int getSubscriptionIdFromNetworkCaps(@NonNull final NetworkCapabilities caps) {
+        // Without an authenticator there is no way to resolve a subId from the capabilities.
+        return (mCarrierPrivilegeAuthenticator == null)
+                ? SubscriptionManager.INVALID_SUBSCRIPTION_ID
+                : mCarrierPrivilegeAuthenticator.getSubIdFromNetworkCapabilities(caps);
+    }
+
     private void handleRegisterNetworkRequestWithIntent(@NonNull final Message msg) {
         final NetworkRequestInfo nri = (NetworkRequestInfo) (msg.obj);
         // handleRegisterNetworkRequestWithIntent() doesn't apply to multilayer requests.
@@ -6010,7 +6009,7 @@
             if (nm == null) return;
 
             if (request == CaptivePortal.APP_REQUEST_REEVALUATION_REQUIRED) {
-                hasNetworkStackPermission();
+                enforceNetworkStackPermission(mContext);
                 nm.forceReevaluation(mDeps.getCallingUid());
             }
         }
@@ -6492,7 +6491,7 @@
                     handleFrozenUids(args.mUids, args.mFrozenStates);
                     break;
                 case EVENT_UID_CARRIER_PRIVILEGES_LOST:
-                    handleUidCarrierPrivilegesLost(msg.arg1);
+                    handleUidCarrierPrivilegesLost(msg.arg1, msg.arg2);
                     break;
             }
         }
@@ -9155,7 +9154,7 @@
         }
     }
 
-    private void handleUidCarrierPrivilegesLost(int uid) {
+    private void handleUidCarrierPrivilegesLost(int uid, int subId) {
         ensureRunningOnConnectivityServiceThread();
         // A NetworkRequest needs to be revoked when all the conditions are met
         //   1. It requests restricted network
@@ -9166,6 +9165,7 @@
             if ((nr.isRequest() || nr.isListen())
                     && !nr.hasCapability(NET_CAPABILITY_NOT_RESTRICTED)
                     && nr.getRequestorUid() == uid
+                    && getSubscriptionIdFromNetworkCaps(nr.networkCapabilities) == subId
                     && !hasConnectivityRestrictedNetworksPermission(uid, true)) {
                 declareNetworkRequestUnfulfillable(nr);
             }
@@ -9174,7 +9174,8 @@
         // A NetworkAgent's allowedUids may need to be updated if the app has lost
         // carrier config
         for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
-            if (nai.networkCapabilities.getAllowedUidsNoCopy().contains(uid)) {
+            if (nai.networkCapabilities.getAllowedUidsNoCopy().contains(uid)
+                    && getSubscriptionIdFromNetworkCaps(nai.networkCapabilities) == subId) {
                 final NetworkCapabilities nc = new NetworkCapabilities(nai.networkCapabilities);
                 NetworkAgentInfo.restrictCapabilitiesFromNetworkAgent(
                         nc,
diff --git a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
index 533278e..04d0fc1 100644
--- a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
+++ b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
@@ -40,12 +40,13 @@
 import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.util.Log;
-import android.util.SparseIntArray;
+import android.util.SparseArray;
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.modules.utils.HandlerExecutor;
+import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.DeviceConfigUtils;
 import com.android.networkstack.apishim.TelephonyManagerShimImpl;
 import com.android.networkstack.apishim.common.TelephonyManagerShim;
@@ -55,6 +56,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.Executor;
+import java.util.function.BiConsumer;
 
 /**
  * Tracks the uid of the carrier privileged app that provides the carrier config.
@@ -71,7 +73,8 @@
     private final TelephonyManagerShim mTelephonyManagerShim;
     private final TelephonyManager mTelephonyManager;
     @GuardedBy("mLock")
-    private final SparseIntArray mCarrierServiceUid = new SparseIntArray(2 /* initialCapacity */);
+    private final SparseArray<CarrierServiceUidWithSubId> mCarrierServiceUidWithSubId =
+            new SparseArray<>(2 /* initialCapacity */);
     @GuardedBy("mLock")
     private int mModemCount = 0;
     private final Object mLock = new Object();
@@ -81,14 +84,14 @@
     private final boolean mUseCallbacksForServiceChanged;
     private final boolean mRequestRestrictedWifiEnabled;
     @NonNull
-    private final CarrierPrivilegesLostListener mListener;
+    private final BiConsumer<Integer, Integer> mListener;
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final Dependencies deps,
             @NonNull final TelephonyManager t,
             @NonNull final TelephonyManagerShim telephonyManagerShim,
             final boolean requestRestrictedWifiEnabled,
-            @NonNull CarrierPrivilegesLostListener listener) {
+            @NonNull BiConsumer<Integer, Integer> listener) {
         mContext = c;
         mTelephonyManager = t;
         mTelephonyManagerShim = telephonyManagerShim;
@@ -121,7 +124,7 @@
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final TelephonyManager t, final boolean requestRestrictedWifiEnabled,
-            @NonNull CarrierPrivilegesLostListener listener) {
+            @NonNull BiConsumer<Integer, Integer> listener) {
         this(c, new Dependencies(), t, TelephonyManagerShimImpl.newInstance(t),
                 requestRestrictedWifiEnabled, listener);
     }
@@ -142,18 +145,6 @@
         }
     }
 
-    /**
-     * Listener interface to get a notification when the carrier App lost its privileges.
-     */
-    public interface CarrierPrivilegesLostListener {
-        /**
-         * Called when the carrier App lost its privileges.
-         *
-         * @param uid  The uid of the carrier app which has lost its privileges.
-         */
-        void onCarrierPrivilegesLost(int uid);
-    }
-
     private void simConfigChanged() {
         synchronized (mLock) {
             unregisterCarrierPrivilegesListeners();
@@ -163,6 +154,29 @@
         }
     }
 
+    private static class CarrierServiceUidWithSubId {
+        final int mUid;
+        final int mSubId;
+
+        CarrierServiceUidWithSubId(int uid, int subId) {
+            mUid = uid;
+            mSubId = subId;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (!(obj instanceof CarrierServiceUidWithSubId)) {
+                return false;
+            }
+            CarrierServiceUidWithSubId compare = (CarrierServiceUidWithSubId) obj;
+            return (mUid == compare.mUid && mSubId == compare.mSubId);
+        }
+
+        @Override
+        public int hashCode() {
+            return mUid * 31 + mSubId;
+        }
+    }
     private class PrivilegeListener implements CarrierPrivilegesListenerShim {
         public final int mLogicalSlot;
 
@@ -192,10 +206,17 @@
                 return;
             }
             synchronized (mLock) {
-                int oldUid = mCarrierServiceUid.get(mLogicalSlot);
-                mCarrierServiceUid.put(mLogicalSlot, carrierServiceUid);
-                if (oldUid != 0 && oldUid != carrierServiceUid) {
-                    mListener.onCarrierPrivilegesLost(oldUid);
+                CarrierServiceUidWithSubId oldPair =
+                        mCarrierServiceUidWithSubId.get(mLogicalSlot);
+                int subId = getSubId(mLogicalSlot);
+                mCarrierServiceUidWithSubId.put(
+                        mLogicalSlot,
+                        new CarrierServiceUidWithSubId(carrierServiceUid, subId));
+                if (oldPair != null
+                        && oldPair.mUid != Process.INVALID_UID
+                        && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID
+                        && !oldPair.equals(mCarrierServiceUidWithSubId.get(mLogicalSlot))) {
+                    mListener.accept(oldPair.mUid, oldPair.mSubId);
                 }
             }
         }
@@ -218,10 +239,13 @@
     private void unregisterCarrierPrivilegesListeners() {
         for (PrivilegeListener carrierPrivilegesListener : mCarrierPrivilegesChangedListeners) {
             removeCarrierPrivilegesListener(carrierPrivilegesListener);
-            int oldUid = mCarrierServiceUid.get(carrierPrivilegesListener.mLogicalSlot);
-            mCarrierServiceUid.delete(carrierPrivilegesListener.mLogicalSlot);
-            if (oldUid != 0) {
-                mListener.onCarrierPrivilegesLost(oldUid);
+            CarrierServiceUidWithSubId oldPair =
+                    mCarrierServiceUidWithSubId.get(carrierPrivilegesListener.mLogicalSlot);
+            mCarrierServiceUidWithSubId.remove(carrierPrivilegesListener.mLogicalSlot);
+            if (oldPair != null
+                    && oldPair.mUid != Process.INVALID_UID
+                    && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID) {
+                mListener.accept(oldPair.mUid, oldPair.mSubId);
             }
         }
         mCarrierPrivilegesChangedListeners.clear();
@@ -259,7 +283,23 @@
      */
     public boolean isCarrierServiceUidForNetworkCapabilities(int callingUid,
             @NonNull NetworkCapabilities networkCapabilities) {
-        if (callingUid == Process.INVALID_UID) return false;
+        if (callingUid == Process.INVALID_UID) {
+            return false;
+        }
+        int subId = getSubIdFromNetworkCapabilities(networkCapabilities);
+        if (SubscriptionManager.INVALID_SUBSCRIPTION_ID == subId) {
+            return false;
+        }
+        return callingUid == getCarrierServiceUidForSubId(subId);
+    }
+
+    /**
+     * Extract the SubscriptionId from the NetworkCapabilities.
+     *
+     * @param networkCapabilities the network capabilities which may contains the SubscriptionId.
+     * @return the SubscriptionId.
+     */
+    public int getSubIdFromNetworkCapabilities(@NonNull NetworkCapabilities networkCapabilities) {
         int subId;
         if (networkCapabilities.hasSingleTransportBesidesTest(TRANSPORT_CELLULAR)) {
             subId = getSubIdFromTelephonySpecifier(networkCapabilities.getNetworkSpecifier());
@@ -285,21 +325,42 @@
             Log.wtf(TAG, "NetworkCapabilities subIds are inconsistent between "
                     + "specifier/transportInfo and mSubIds : " + networkCapabilities);
         }
-        if (SubscriptionManager.INVALID_SUBSCRIPTION_ID == subId) return false;
-        return callingUid == getCarrierServiceUidForSubId(subId);
+        return subId;
+    }
+
+    @VisibleForTesting
+    protected int getSubId(int slotIndex) {
+        // Returns the subId in logical slot slotIndex, or INVALID_SUBSCRIPTION_ID if none.
+        if (SdkLevel.isAtLeastU()) {
+            return SubscriptionManager.getSubscriptionId(slotIndex);
+        }
+        // Pre-U fallback: the static getSubscriptionId(int) API is only available from U.
+        final SubscriptionManager sm = mContext.getSystemService(SubscriptionManager.class);
+        // getSystemService() is @Nullable; treat a missing service as "no subscription".
+        final int[] subIds = (sm == null) ? null : sm.getSubscriptionIds(slotIndex);
+        return (subIds != null && subIds.length > 0)
+                ? subIds[0] : SubscriptionManager.INVALID_SUBSCRIPTION_ID;
     }
 
     @VisibleForTesting
     void updateCarrierServiceUid() {
         synchronized (mLock) {
-            SparseIntArray oldCarrierServiceUid = mCarrierServiceUid.clone();
-            mCarrierServiceUid.clear();
+            SparseArray<CarrierServiceUidWithSubId> copy = mCarrierServiceUidWithSubId.clone();
+            mCarrierServiceUidWithSubId.clear();
             for (int i = 0; i < mModemCount; i++) {
-                mCarrierServiceUid.put(i, getCarrierServicePackageUidForSlot(i));
+                int subId = getSubId(i);
+                mCarrierServiceUidWithSubId.put(
+                        i,
+                        new CarrierServiceUidWithSubId(
+                                getCarrierServicePackageUidForSlot(i), subId));
             }
-            for (int i = 0; i < oldCarrierServiceUid.size(); i++) {
-                if (mCarrierServiceUid.indexOfValue(oldCarrierServiceUid.valueAt(i)) < 0) {
-                    mListener.onCarrierPrivilegesLost(oldCarrierServiceUid.valueAt(i));
+            for (int i = 0; i < copy.size(); ++i) {
+                CarrierServiceUidWithSubId oldPair = copy.valueAt(i);
+                CarrierServiceUidWithSubId newPair = mCarrierServiceUidWithSubId.get(copy.keyAt(i));
+                if (oldPair.mUid != Process.INVALID_UID
+                        && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID
+                        && !oldPair.equals(newPair)) {
+                    mListener.accept(oldPair.mUid, oldPair.mSubId);
                 }
             }
         }
@@ -307,18 +368,17 @@
 
     @VisibleForTesting
     int getCarrierServiceUidForSubId(int subId) {
-        final int slotId = getSlotIndex(subId);
         synchronized (mLock) {
-            return mCarrierServiceUid.get(slotId, Process.INVALID_UID);
+            for (int i = 0; i < mCarrierServiceUidWithSubId.size(); ++i) {
+                if (mCarrierServiceUidWithSubId.valueAt(i).mSubId == subId) {
+                    return mCarrierServiceUidWithSubId.valueAt(i).mUid;
+                }
+            }
+            return Process.INVALID_UID;
         }
     }
 
     @VisibleForTesting
-    protected int getSlotIndex(int subId) {
-        return SubscriptionManager.getSlotIndex(subId);
-    }
-
-    @VisibleForTesting
     int getUidForPackage(String pkgName) {
         if (pkgName == null) {
             return Process.INVALID_UID;
@@ -383,11 +443,12 @@
         pw.println("CarrierPrivilegeAuthenticator:");
         pw.println("mRequestRestrictedWifiEnabled = " + mRequestRestrictedWifiEnabled);
         synchronized (mLock) {
-            final int size = mCarrierServiceUid.size();
-            for (int i = 0; i < size; ++i) {
-                final int logicalSlot = mCarrierServiceUid.keyAt(i);
-                final int serviceUid = mCarrierServiceUid.valueAt(i);
-                pw.println("Logical slot = " + logicalSlot + " : uid = " + serviceUid);
+            for (int i = 0; i < mCarrierServiceUidWithSubId.size(); ++i) {
+                final int logicalSlot = mCarrierServiceUidWithSubId.keyAt(i);
+                final int serviceUid = mCarrierServiceUidWithSubId.valueAt(i).mUid;
+                final int subId = mCarrierServiceUidWithSubId.valueAt(i).mSubId;
+                pw.println("Logical slot = " + logicalSlot + " : uid = " + serviceUid
+                        + " : subId = " + subId);
             }
         }
     }
diff --git a/service/src/com/android/server/connectivity/SatelliteAccessController.java b/service/src/com/android/server/connectivity/SatelliteAccessController.java
index 0968aff..b53abce 100644
--- a/service/src/com/android/server/connectivity/SatelliteAccessController.java
+++ b/service/src/com/android/server/connectivity/SatelliteAccessController.java
@@ -26,8 +26,10 @@
 import android.os.Handler;
 import android.os.Process;
 import android.os.UserHandle;
+import android.os.UserManager;
 import android.util.ArraySet;
 import android.util.Log;
+import android.util.SparseArray;
 
 import com.android.internal.annotations.VisibleForTesting;
 
@@ -44,13 +46,18 @@
  */
 public class SatelliteAccessController {
     private static final String TAG = SatelliteAccessController.class.getSimpleName();
-    private final PackageManager mPackageManager;
+    private final Context mContext;
     private final Dependencies mDeps;
     private final DefaultMessageRoleListener mDefaultMessageRoleListener;
+    private final UserManager mUserManager;
     private final Consumer<Set<Integer>> mCallback;
-    private final Set<Integer> mSatelliteNetworkPreferredUidCache = new ArraySet<>();
     private final Handler mConnectivityServiceHandler;
 
+    // At this sparseArray, Key is userId and values are uids of SMS apps that are allowed
+    // to use satellite network as fallback.
+    private final SparseArray<Set<Integer>> mAllUsersSatelliteNetworkFallbackUidCache =
+            new SparseArray<>();
+
     /**
      *  Monitor {@link android.app.role.OnRoleHoldersChangedListener#onRoleHoldersChanged(String,
      *  UserHandle)},
@@ -59,10 +66,10 @@
     private final class DefaultMessageRoleListener
             implements OnRoleHoldersChangedListener {
         @Override
-        public void onRoleHoldersChanged(String role, UserHandle user) {
+        public void onRoleHoldersChanged(String role, UserHandle userHandle) {
             if (RoleManager.ROLE_SMS.equals(role)) {
                 Log.i(TAG, "ROLE_SMS Change detected ");
-                onRoleSmsChanged();
+                onRoleSmsChanged(userHandle);
             }
         }
 
@@ -71,7 +78,7 @@
                 mDeps.addOnRoleHoldersChangedListenerAsUser(
                         mConnectivityServiceHandler::post, this, UserHandle.ALL);
             } catch (RuntimeException e) {
-                Log.e(TAG, "Could not register satellite controller listener due to " + e);
+                Log.wtf(TAG, "Could not register satellite controller listener due to " + e);
             }
         }
     }
@@ -89,9 +96,9 @@
             mRoleManager = context.getSystemService(RoleManager.class);
         }
 
-        /** See {@link RoleManager#getRoleHolders(String)} */
-        public List<String> getRoleHolders(String roleName) {
-            return mRoleManager.getRoleHolders(roleName);
+        /** See {@link RoleManager#getRoleHoldersAsUser(String, UserHandle)} */
+        public List<String> getRoleHoldersAsUser(String roleName, UserHandle userHandle) {
+            return mRoleManager.getRoleHoldersAsUser(roleName, userHandle);
         }
 
         /** See {@link RoleManager#addOnRoleHoldersChangedListenerAsUser} */
@@ -105,81 +112,107 @@
     SatelliteAccessController(@NonNull final Context c, @NonNull final Dependencies deps,
             Consumer<Set<Integer>> callback,
             @NonNull final Handler connectivityServiceInternalHandler) {
+        mContext = c;
         mDeps = deps;
-        mPackageManager = c.getPackageManager();
+        mUserManager = mContext.getSystemService(UserManager.class);
         mDefaultMessageRoleListener = new DefaultMessageRoleListener();
         mCallback = callback;
         mConnectivityServiceHandler = connectivityServiceInternalHandler;
     }
 
-    private void updateSatelliteNetworkPreferredUidListCache(List<String> packageNames) {
-        for (String packageName : packageNames) {
-            // Check if SATELLITE_COMMUNICATION permission is enabled for default sms application
-            // package before adding it part of satellite network preferred uid cache list.
-            if (isSatellitePermissionEnabled(packageName)) {
-                mSatelliteNetworkPreferredUidCache.add(getUidForPackage(packageName));
+    private Set<Integer> updateSatelliteNetworkFallbackUidListCache(List<String> packageNames,
+            @NonNull UserHandle userHandle) {
+        Set<Integer> fallbackUids = new ArraySet<>();
+        PackageManager pm =
+                mContext.createContextAsUser(userHandle, 0).getPackageManager();
+        if (pm != null) {
+            for (String packageName : packageNames) {
+                // Check if the SATELLITE_COMMUNICATION permission is granted to the default SMS
+                // application package before adding its uid to the satellite network fallback
+                // uid cache list.
+                if (isSatellitePermissionEnabled(pm, packageName)) {
+                    int uid = getUidForPackage(pm, packageName);
+                    if (uid != Process.INVALID_UID) {
+                        fallbackUids.add(uid);
+                    }
+                }
             }
+        } else {
+            Log.wtf(TAG, "package manager found null");
         }
+        return fallbackUids;
     }
 
     //Check if satellite communication is enabled for the package
-    private boolean isSatellitePermissionEnabled(String packageName) {
-        if (mPackageManager != null) {
-            return mPackageManager.checkPermission(
-                    Manifest.permission.SATELLITE_COMMUNICATION, packageName)
-                    == PackageManager.PERMISSION_GRANTED;
-        }
-        return false;
+    private boolean isSatellitePermissionEnabled(PackageManager packageManager,
+            String packageName) {
+        return packageManager.checkPermission(
+                Manifest.permission.SATELLITE_COMMUNICATION, packageName)
+                == PackageManager.PERMISSION_GRANTED;
     }
 
-    private int getUidForPackage(String pkgName) {
+    private int getUidForPackage(PackageManager packageManager, String pkgName) {
         if (pkgName == null) {
             return Process.INVALID_UID;
         }
         try {
-            if (mPackageManager != null) {
-                ApplicationInfo applicationInfo = mPackageManager.getApplicationInfo(pkgName, 0);
-                if (applicationInfo != null) {
-                    return applicationInfo.uid;
-                }
-            }
+            ApplicationInfo applicationInfo = packageManager.getApplicationInfo(pkgName, 0);
+            return applicationInfo.uid;
         } catch (PackageManager.NameNotFoundException exception) {
             Log.e(TAG, "Unable to find uid for package: " + pkgName);
         }
         return Process.INVALID_UID;
     }
 
-    //on Role sms change triggered by OnRoleHoldersChangedListener()
-    private void onRoleSmsChanged() {
-        final List<String> packageNames = getRoleSmsChangedPackageName();
-
-        // Create a new Set
-        Set<Integer> previousSatellitePreferredUid = new ArraySet<>(
-                mSatelliteNetworkPreferredUidCache);
-
-        mSatelliteNetworkPreferredUidCache.clear();
-
-        if (packageNames != null) {
-            Log.i(TAG, "role_sms_packages: " + packageNames);
-            // On Role change listener, update the satellite network preferred uid cache list
-            updateSatelliteNetworkPreferredUidListCache(packageNames);
-            Log.i(TAG, "satellite_preferred_uid: " + mSatelliteNetworkPreferredUidCache);
-        } else {
-            Log.wtf(TAG, "package name was found null");
+    // Called on ROLE_SMS holder changes via OnRoleHoldersChangedListener, and at start().
+    // TODO(b/326373613): use a UserLifecycleListener to receive a callback when a user is
+    // removed, so that the uid list and the multi-layer request can be updated for the user
+    // deletion scenario.
+    private void onRoleSmsChanged(@NonNull UserHandle userHandle) {
+        int userId = userHandle.getIdentifier();
+        if (userId == Process.INVALID_UID) {
+            Log.wtf(TAG, "Invalid User Id");
+            return;
         }
 
+        // Returns an empty list if no package holds ROLE_SMS for this user.
+        final List<String> packageNames =
+                mDeps.getRoleHoldersAsUser(RoleManager.ROLE_SMS, userHandle);
+
+        // Previously cached satellite fallback uids for this user (empty set if none).
+        final Set<Integer> prevUidsForUser =
+                mAllUsersSatelliteNetworkFallbackUidCache.get(userId, new ArraySet<>());
+
+        Log.i(TAG, "currentUser : role_sms_packages: " + userId + " : " + packageNames);
+        final Set<Integer> newUidsForUser = !packageNames.isEmpty()
+                ? updateSatelliteNetworkFallbackUidListCache(packageNames, userHandle)
+                : new ArraySet<>();
+        Log.i(TAG, "satellite_fallback_uid: " + newUidsForUser);
+
         // on Role change, update the multilayer request at ConnectivityService with updated
-        // satellite network preferred uid cache list if changed or to revoke for previous default
-        // sms app
-        if (!mSatelliteNetworkPreferredUidCache.equals(previousSatellitePreferredUid)) {
-            Log.i(TAG, "update multi layer request");
-            mCallback.accept(mSatelliteNetworkPreferredUidCache);
+        // satellite network fallback uid cache list of multiple users as applicable
+        if (newUidsForUser.equals(prevUidsForUser)) {
+            return;
         }
+
+        mAllUsersSatelliteNetworkFallbackUidCache.put(userId, newUidsForUser);
+
+        // Merge the cached fallback uids of all users.
+        Set<Integer> mergedSatelliteNetworkFallbackUidCache = new ArraySet<>();
+        for (int i = 0; i < mAllUsersSatelliteNetworkFallbackUidCache.size(); i++) {
+            mergedSatelliteNetworkFallbackUidCache.addAll(
+                    mAllUsersSatelliteNetworkFallbackUidCache.valueAt(i));
+        }
+        Log.i(TAG, "merged uid list for multi layer request : "
+                + mergedSatelliteNetworkFallbackUidCache);
+
+        // Trigger the multi-layer request for satellite network fallback with the merged uids.
+        mCallback.accept(mergedSatelliteNetworkFallbackUidCache);
     }
 
-    private List<String> getRoleSmsChangedPackageName() {
+    private List<String> getRoleSmsChangedPackageName(UserHandle userHandle) {
         try {
-            return mDeps.getRoleHolders(RoleManager.ROLE_SMS);
+            return mDeps.getRoleHoldersAsUser(RoleManager.ROLE_SMS, userHandle);
         } catch (RuntimeException e) {
             Log.wtf(TAG, "Could not get package name at role sms change update due to: " + e);
             return null;
@@ -188,7 +221,16 @@
 
     /** Register OnRoleHoldersChangedListener */
     public void start() {
-        mConnectivityServiceHandler.post(this::onRoleSmsChanged);
+        mConnectivityServiceHandler.post(this::updateAllUserRoleSmsUids);
         mDefaultMessageRoleListener.register();
     }
+
+    private void updateAllUserRoleSmsUids() {
+        List<UserHandle> existingUsers = mUserManager.getUserHandles(true /* excludeDying */);
+        // Iterate through the user handles and obtain their uids with role sms and satellite
+        // communication permission
+        for (UserHandle userHandle : existingUsers) {
+            onRoleSmsChanged(userHandle);
+        }
+    }
 }
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/ExternalPacketForwarder.kt b/staticlibs/testutils/devicetests/com/android/testutils/ExternalPacketForwarder.kt
new file mode 100644
index 0000000..36eb795
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/ExternalPacketForwarder.kt
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.testutils
+
+import java.io.FileDescriptor
+
+class ExternalPacketForwarder(
+    srcFd: FileDescriptor,
+    mtu: Int,
+    dstFd: FileDescriptor,
+    forwardMap: Map<Int, Int>
+) : PacketForwarderBase(srcFd, mtu, dstFd, forwardMap) {
+
+    /**
+     * Prepares a packet for forwarding by potentially updating the
+     * source port based on the specified port remapping rules.
+     *
+     * @param buf The packet data as a byte array.
+     * @param version The IP version of the packet (e.g., 4 for IPv4).
+     */
+    override fun remapPort(buf: ByteArray, version: Int) {
+        val transportOffset = getTransportOffset(version)
+        val intPort = getRemappedPort(buf, transportOffset)
+
+        // Copy remapped source port.
+        if (intPort != 0) {
+            setPortAt(intPort, buf, transportOffset)
+        }
+   }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/InternalPacketForwarder.kt b/staticlibs/testutils/devicetests/com/android/testutils/InternalPacketForwarder.kt
new file mode 100644
index 0000000..58829dc
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/InternalPacketForwarder.kt
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.testutils
+
+import java.io.FileDescriptor
+
+class InternalPacketForwarder(
+    srcFd: FileDescriptor,
+    mtu: Int,
+    dstFd: FileDescriptor,
+    forwardMap: Map<Int, Int>
+) : PacketForwarderBase(srcFd, mtu, dstFd, forwardMap) {
+    /**
+     * Prepares a packet for forwarding by potentially updating the
+     * destination port based on the specified port remapping rules.
+     *
+     * @param buf The packet data as a byte array.
+     * @param version The IP version of the packet (e.g., 4 for IPv4).
+     */
+    override fun remapPort(buf: ByteArray, version: Int) {
+        val transportOffset = getTransportOffset(version) + DESTINATION_PORT_OFFSET
+        val extPort = getRemappedPort(buf, transportOffset)
+
+        // Copy remapped destination port.
+        if (extPort != 0) {
+            setPortAt(extPort, buf, transportOffset)
+        }
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/PacketBridge.kt b/staticlibs/testutils/devicetests/com/android/testutils/PacketBridge.kt
index 1a2cc88..0b736d1 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/PacketBridge.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/PacketBridge.kt
@@ -40,7 +40,8 @@
 class PacketBridge(
     context: Context,
     addresses: List<LinkAddress>,
-    dnsAddr: InetAddress
+    dnsAddr: InetAddress,
+    portMapping: List<Pair<Int, Int>>
 ) {
     private val binder = Binder()
 
@@ -56,6 +57,10 @@
     // Register test networks to ConnectivityService.
     private val internalNetworkCallback: TestableNetworkCallback
     private val externalNetworkCallback: TestableNetworkCallback
+
+    private val internalForwardMap = HashMap<Int, Int>()
+    private val externalForwardMap = HashMap<Int, Int>()
+
     val internalNetwork: Network
     val externalNetwork: Network
     init {
@@ -65,14 +70,28 @@
         externalNetworkCallback = exCb
         internalNetwork = inNet
         externalNetwork = exNet
+        for (mapping in portMapping) {
+            internalForwardMap[mapping.first] = mapping.second
+            externalForwardMap[mapping.second] = mapping.first
+        }
     }
 
     // Set up the packet bridge.
     private val internalFd = internalIface.fileDescriptor.fileDescriptor
     private val externalFd = externalIface.fileDescriptor.fileDescriptor
 
-    private val pr1 = PacketForwarder(internalFd, 1500, externalFd)
-    private val pr2 = PacketForwarder(externalFd, 1500, internalFd)
+    private val pr1 = InternalPacketForwarder(
+        internalFd,
+        1500,
+        externalFd,
+        internalForwardMap
+    )
+    private val pr2 = ExternalPacketForwarder(
+        externalFd,
+        1500,
+        internalFd,
+        externalForwardMap
+    )
 
     fun start() {
         IoUtils.setBlocking(internalFd, true /* blocking */)
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/PacketForwarder.java b/staticlibs/testutils/devicetests/com/android/testutils/PacketForwarderBase.java
similarity index 68%
rename from staticlibs/testutils/devicetests/com/android/testutils/PacketForwarder.java
rename to staticlibs/testutils/devicetests/com/android/testutils/PacketForwarderBase.java
index d8efb7d..5c79eb0 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/PacketForwarder.java
+++ b/staticlibs/testutils/devicetests/com/android/testutils/PacketForwarderBase.java
@@ -32,6 +32,7 @@
 
 import java.io.FileDescriptor;
 import java.io.IOException;
+import java.util.Map;
 import java.util.Objects;
 
 /**
@@ -57,8 +58,9 @@
  * from the http server, the same mechanism is applied but in a different direction,
  * where the source and destination will be swapped.
  */
-public class PacketForwarder extends Thread {
+public abstract class PacketForwarderBase extends Thread {
     private static final String TAG = "PacketForwarder";
+    static final int DESTINATION_PORT_OFFSET = 2;
 
     // The source fd to read packets from.
     @NonNull
@@ -70,8 +72,10 @@
     @NonNull
     final FileDescriptor mDstFd;
 
+    @NonNull
+    final Map<Integer, Integer> mPortRemapRules;
     /**
-     * Construct a {@link PacketForwarder}.
+     * Construct a {@link PacketForwarderBase}.
      *
      * This class reads packets from {@code srcFd} of a {@link TestNetworkInterface}, and
      * forwards them to the {@code dstFd} of another {@link TestNetworkInterface}.
@@ -82,13 +86,49 @@
      * @param srcFd   {@link FileDescriptor} to read packets from.
      * @param mtu     MTU of the test network.
      * @param dstFd   {@link FileDescriptor} to write packets to.
+     * @param portRemapRules map from an original port to the port it is remapped to.
      */
-    public PacketForwarder(@NonNull FileDescriptor srcFd, int mtu,
-                           @NonNull FileDescriptor dstFd) {
+    public PacketForwarderBase(@NonNull FileDescriptor srcFd, int mtu,
+                           @NonNull FileDescriptor dstFd,
+                           @NonNull Map<Integer, Integer> portRemapRules) {
         super(TAG);
         mSrcFd = Objects.requireNonNull(srcFd);
         mBuf = new byte[mtu];
         mDstFd = Objects.requireNonNull(dstFd);
+        mPortRemapRules = Objects.requireNonNull(portRemapRules);
+    }
+
+    /**
+     * Remaps the ports of a packet before it is forwarded between two instances of
+     * {@link TestNetworkInterface}.
+     * Subclasses should override this method to implement the needed port remapping.
+     * The internal forwarder remaps the destination port; the external forwarder
+     * remaps the source port.
+     * Example:
+     * An outgoing packet from the internal interface with
+     * source 1.2.3.4:1234 and destination 8.8.8.8:80
+     * might be translated to 8.8.8.8:1234 -> 1.2.3.4:8080 before forwarding.
+     * An outgoing packet from the external interface with
+     * source 1.2.3.4:8080 and destination 8.8.8.8:1234
+     * might be translated to 8.8.8.8:80 -> 1.2.3.4:1234 before forwarding.
+     */
+    abstract void remapPort(@NonNull byte[] buf, int version);
+
+    /**
+     * Retrieves a potentially remapped port number from a packet.
+     *
+     * @param buf            The packet data as a byte array.
+     * @param transportOffset The offset within the packet where the transport layer port begins.
+     * @return The remapped port if a mapping exists in this forwarder's port remap rules,
+     *         otherwise returns 0 (indicating no remapping).
+     */
+    int getRemappedPort(@NonNull byte[] buf, int transportOffset) {
+        int port = PacketReflectorUtil.getPortAt(buf, transportOffset);
+        return mPortRemapRules.getOrDefault(port, 0);
+    }
+
+    int getTransportOffset(int version) {
+        return version == 4 ? IPV4_HEADER_LENGTH : IPV6_HEADER_LENGTH;
     }
 
     private void forwardPacket(@NonNull byte[] buf, int len) {
@@ -99,7 +139,13 @@
         }
     }
 
-    // Reads one packet from mSrcFd, and writes the packet to the mDstFd for supported protocols.
+    /**
+     * Reads one packet from mSrcFd, and writes the packet to mDstFd for supported protocols.
+     * This includes:
+     * 1. Address swapping: swaps source and destination IP addresses.
+     * 2. Port remapping: remaps the port if necessary.
+     * 3. Checksum recalculation: updates IP and transport layer checksums to reflect changes.
+     */
     private void processPacket() {
         final int len = PacketReflectorUtil.readPacket(mSrcFd, mBuf);
         if (len < 1) {
@@ -142,13 +188,19 @@
         if (len < ipHdrLen + transportHdrLen) {
             throw new IllegalStateException("Unexpected buffer length: " + len);
         }
-        // Swap addresses.
+
+        // Swap source and destination address.
         PacketReflectorUtil.swapAddresses(mBuf, version);
 
+        // Remapping the port.
+        remapPort(mBuf, version);
+
+        // Fix IP and Transport layer checksum.
+        PacketReflectorUtil.fixPacketChecksum(mBuf, len, version, proto);
+
         // Send the packet to the destination fd.
         forwardPacket(mBuf, len);
     }
-
     @Override
     public void run() {
         Log.i(TAG, "starting fd=" + mSrcFd + " valid=" + mSrcFd.valid());
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestHttpServer.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestHttpServer.kt
index 740bf63..f1f0c1c 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestHttpServer.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestHttpServer.kt
@@ -25,8 +25,10 @@
  * A minimal HTTP server running on a random available port.
  *
  * @param host The host to listen to, or null to listen on all hosts
+ * @param port The port to listen to, or 0 to auto select
  */
-class TestHttpServer(host: String? = null) : NanoHTTPD(host, 0 /* auto-select the port */) {
+class TestHttpServer
+    @JvmOverloads constructor(host: String? = null, port: Int = 0) : NanoHTTPD(host, port) {
     // Map of URL path -> HTTP response code
     private val responses = HashMap<Request, Response>()
 
diff --git a/tests/cts/net/src/android/net/cts/DnsResolverTest.java b/tests/cts/net/src/android/net/cts/DnsResolverTest.java
index 9ff0f2f..752891f 100644
--- a/tests/cts/net/src/android/net/cts/DnsResolverTest.java
+++ b/tests/cts/net/src/android/net/cts/DnsResolverTest.java
@@ -23,6 +23,7 @@
 import static android.net.DnsResolver.TYPE_AAAA;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.cts.util.CtsNetUtils.TestNetworkCallback;
+import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
 import static android.system.OsConstants.ETIMEDOUT;
 
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
@@ -59,11 +60,14 @@
 import com.android.net.module.util.DnsPacket;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
+import com.android.testutils.DeviceConfigRule;
 import com.android.testutils.DnsResolverModuleTest;
 import com.android.testutils.SkipPresubmit;
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -80,6 +84,8 @@
 @AppModeFull(reason = "WRITE_SECURE_SETTINGS permission can't be granted to instant apps")
 @RunWith(AndroidJUnit4.class)
 public class DnsResolverTest {
+    @ClassRule
+    public static final DeviceConfigRule DEVICE_CONFIG_CLASS_RULE = new DeviceConfigRule();
     @Rule
     public final DevSdkIgnoreRule ignoreRule = new DevSdkIgnoreRule();
 
@@ -123,6 +129,20 @@
 
     private TestNetworkCallback mWifiRequestCallback = null;
 
+    /**
+     * @see BeforeClass
+     */
+    @BeforeClass
+    public static void beforeClass() throws Exception {
+        // Use async private DNS resolution to avoid flakes due to races applying the setting
+        DEVICE_CONFIG_CLASS_RULE.setConfig(NAMESPACE_CONNECTIVITY,
+                "networkmonitor_async_privdns_resolution", "1");
+        // Make sure NetworkMonitor is restarted before and after the test so the flag is applied
+        // and cleaned up.
+        maybeToggleWifiAndCell();
+        DEVICE_CONFIG_CLASS_RULE.runAfterNextCleanup(DnsResolverTest::maybeToggleWifiAndCell);
+    }
+
     @Before
     public void setUp() throws Exception {
         mContext = InstrumentationRegistry.getContext();
@@ -144,6 +164,12 @@
         }
     }
 
+    private static void maybeToggleWifiAndCell() throws Exception {
+        final CtsNetUtils utils = new CtsNetUtils(InstrumentationRegistry.getContext());
+        utils.reconnectWifiIfSupported();
+        utils.reconnectCellIfSupported();
+    }
+
     private static String byteArrayToHexString(byte[] bytes) {
         char[] hexChars = new char[bytes.length * 2];
         for (int i = 0; i < bytes.length; ++i) {
diff --git a/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java b/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
index 17a9ca2..bca18f5 100644
--- a/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
+++ b/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
@@ -17,6 +17,12 @@
 package android.net.cts;
 
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
 
 import android.content.ContentResolver;
 import android.content.Context;
@@ -28,9 +34,21 @@
 import android.platform.test.annotations.AppModeFull;
 import android.system.ErrnoException;
 import android.system.OsConstants;
-import android.test.AndroidTestCase;
 
-public class MultinetworkApiTest extends AndroidTestCase {
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.platform.app.InstrumentationRegistry;
+
+import com.android.testutils.DeviceConfigRule;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
+public class MultinetworkApiTest {
+    @Rule
+    public final DeviceConfigRule mDeviceConfigRule = new DeviceConfigRule();
 
     static {
         System.loadLibrary("nativemultinetwork_jni");
@@ -58,20 +76,17 @@
     private CtsNetUtils mCtsNetUtils;
     private String mOldMode;
     private String mOldDnsSpecifier;
+    private Context mContext;
 
-    @Override
-    protected void setUp() throws Exception {
-        super.setUp();
-        mCM = (ConnectivityManager) getContext().getSystemService(Context.CONNECTIVITY_SERVICE);
-        mCR = getContext().getContentResolver();
-        mCtsNetUtils = new CtsNetUtils(getContext());
+    @Before
+    public void setUp() throws Exception {
+        mContext = InstrumentationRegistry.getInstrumentation().getContext();
+        mCM = mContext.getSystemService(ConnectivityManager.class);
+        mCR = mContext.getContentResolver();
+        mCtsNetUtils = new CtsNetUtils(mContext);
     }
 
-    @Override
-    protected void tearDown() throws Exception {
-        super.tearDown();
-    }
-
+    @Test
     public void testGetaddrinfo() throws ErrnoException {
         for (Network network : mCtsNetUtils.getTestableNetworks()) {
             int errno = runGetaddrinfoCheck(network.getNetworkHandle());
@@ -82,6 +97,7 @@
         }
     }
 
+    @Test
     @AppModeFull(reason = "CHANGE_NETWORK_STATE permission can't be granted to instant apps")
     public void testSetprocnetwork() throws ErrnoException {
         // Hopefully no prior test in this process space has set a default network.
@@ -125,6 +141,7 @@
         }
     }
 
+    @Test
     @AppModeFull(reason = "CHANGE_NETWORK_STATE permission can't be granted to instant apps")
     public void testSetsocknetwork() throws ErrnoException {
         for (Network network : mCtsNetUtils.getTestableNetworks()) {
@@ -136,6 +153,7 @@
         }
     }
 
+    @Test
     public void testNativeDatagramTransmission() throws ErrnoException {
         for (Network network : mCtsNetUtils.getTestableNetworks()) {
             int errno = runDatagramCheck(network.getNetworkHandle());
@@ -146,6 +164,7 @@
         }
     }
 
+    @Test
     public void testNoSuchNetwork() {
         final Network eNoNet = new Network(54321);
         assertNull(mCM.getNetworkInfo(eNoNet));
@@ -158,6 +177,7 @@
         // assertEquals(-OsConstants.ENONET, runGetaddrinfoCheck(eNoNetHandle));
     }
 
+    @Test
     public void testNetworkHandle() {
         // Test Network -> NetworkHandle -> Network results in the same Network.
         for (Network network : mCtsNetUtils.getTestableNetworks()) {
@@ -181,6 +201,7 @@
         } catch (IllegalArgumentException e) {}
     }
 
+    @Test
     public void testResNApi() throws Exception {
         final Network[] testNetworks = mCtsNetUtils.getTestableNetworks();
 
@@ -201,9 +222,21 @@
         }
     }
 
+    @Test
     @AppModeFull(reason = "WRITE_SECURE_SETTINGS permission can't be granted to instant apps")
-    public void testResNApiNXDomainPrivateDns() throws InterruptedException {
+    public void testResNApiNXDomainPrivateDns() throws Exception {
+        // Use async private DNS resolution to avoid flakes due to races applying the setting
+        mDeviceConfigRule.setConfig(NAMESPACE_CONNECTIVITY,
+                "networkmonitor_async_privdns_resolution", "1");
+        mCtsNetUtils.reconnectWifiIfSupported();
+        mCtsNetUtils.reconnectCellIfSupported();
+
         mCtsNetUtils.storePrivateDnsSetting();
+
+        mDeviceConfigRule.runAfterNextCleanup(() -> {
+            mCtsNetUtils.reconnectWifiIfSupported();
+            mCtsNetUtils.reconnectCellIfSupported();
+        });
         // Enable private DNS strict mode and set server to dns.google before doing NxDomain test.
         // b/144521720
         try {
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 6db372f..ce2c2c1 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -1716,6 +1716,177 @@
         }
     }
 
+    @Test
+    fun testReplyWhenKnownAnswerSuppressionFlagSet() {
+        // The flag may be removed in the future but known-answer suppression should be enabled by
+        // default in that case. The rule will reset flags automatically on teardown.
+        deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_known_answer_suppression", "1")
+        deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_unicast_reply_enabled", "1")
+
+        val si = makeTestServiceInfo(testNetwork1.network)
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        var nsResponder: NSResponder? = null
+        tryTest {
+            registerService(registrationRecord, si)
+            val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                    testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+            packetReader.startAsyncForTest()
+
+            handlerThread.waitForIdle(TIMEOUT_MS)
+            /*
+            Send a query with a known answer. Expect to receive a response containing TXT record
+            only.
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=0, aa=0, qd =
+                    scapy.DNSQR(qname='_nmt123456789._tcp.local', qtype='PTR',
+                            qclass=0x8001) /
+                    scapy.DNSQR(qname='NsdTest123456789._nmt123456789._tcp.local', qtype='TXT',
+                            qclass=0x8001),
+                    an = scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=4500,
+                            rdata='NsdTest123456789._nmt123456789._tcp.local')
+            )).hex()
+            */
+            val query = HexDump.hexStringToByteArray("0000000000020001000000000d5f6e6d74313233343" +
+                    "536373839045f746370056c6f63616c00000c8001104e7364546573743132333435363738390" +
+                    "d5f6e6d74313233343536373839045f746370056c6f63616c00001080010d5f6e6d743132333" +
+                    "43536373839045f746370056c6f63616c00000c000100001194002b104e73645465737431323" +
+                    "33435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
+            replaceServiceNameAndTypeWithTestSuffix(query)
+
+            val testSrcAddr = makeLinkLocalAddressOfOtherDeviceOnPrefix(testNetwork1.network)
+            nsResponder = NSResponder(packetReader, mapOf(
+                    testSrcAddr to MacAddress.fromString("01:02:03:04:05:06")
+            )).apply { start() }
+
+            packetReader.sendResponse(buildMdnsPacket(query, testSrcAddr))
+            // The reply is sent unicast to the source address. There may be announcements sent
+            // multicast around this time, so filter by destination address.
+            val reply = packetReader.pollForMdnsPacket { pkt ->
+                pkt.isReplyFor("$serviceName.$serviceType.local", DnsResolver.TYPE_TXT) &&
+                        !pkt.isReplyFor("$serviceType.local", DnsResolver.TYPE_PTR) &&
+                        pkt.dstAddr == testSrcAddr
+            }
+            assertNotNull(reply)
+
+            /*
+            Send a query with a known answer (TTL is less than half). Expect to receive a response
+            containing both PTR and TXT records.
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=0, aa=0, qd =
+                    scapy.DNSQR(qname='_nmt123456789._tcp.local', qtype='PTR',
+                            qclass=0x8001) /
+                    scapy.DNSQR(qname='NsdTest123456789._nmt123456789._tcp.local', qtype='TXT',
+                            qclass=0x8001),
+                    an = scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=2150,
+                            rdata='NsdTest123456789._nmt123456789._tcp.local')
+            )).hex()
+            */
+            val query2 = HexDump.hexStringToByteArray("0000000000020001000000000d5f6e6d7431323334" +
+                    "3536373839045f746370056c6f63616c00000c8001104e736454657374313233343536373839" +
+                    "0d5f6e6d74313233343536373839045f746370056c6f63616c00001080010d5f6e6d74313233" +
+                    "343536373839045f746370056c6f63616c00000c000100000866002b104e7364546573743132" +
+                    "333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
+            replaceServiceNameAndTypeWithTestSuffix(query2)
+
+            packetReader.sendResponse(buildMdnsPacket(query2, testSrcAddr))
+            // The reply is sent unicast to the source address. There may be announcements sent
+            // multicast around this time, so filter by destination address.
+            val reply2 = packetReader.pollForMdnsPacket { pkt ->
+                pkt.isReplyFor("$serviceName.$serviceType.local", DnsResolver.TYPE_TXT) &&
+                        pkt.isReplyFor("$serviceType.local", DnsResolver.TYPE_PTR) &&
+                        pkt.dstAddr == testSrcAddr
+            }
+            assertNotNull(reply2)
+        } cleanup {
+            nsResponder?.stop()
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+        }
+    }
+
+    @Test
+    fun testReplyWithMultipacketWhenKnownAnswerSuppressionFlagSet() {
+        // The flag may be removed in the future but known-answer suppression should be enabled by
+        // default in that case. The rule will reset flags automatically on teardown.
+        deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_known_answer_suppression", "1")
+        deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_unicast_reply_enabled", "1")
+
+        val si = makeTestServiceInfo(testNetwork1.network)
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        var nsResponder: NSResponder? = null
+        tryTest {
+            registerService(registrationRecord, si)
+            val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                    testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+            packetReader.startAsyncForTest()
+
+            handlerThread.waitForIdle(TIMEOUT_MS)
+            /*
+            Send a query with truncated bit set.
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=0, aa=0, tc=1, qd=
+                    scapy.DNSQR(qname='_nmt123456789._tcp.local', qtype='PTR',
+                            qclass=0x8001) /
+                    scapy.DNSQR(qname='NsdTest123456789._nmt123456789._tcp.local', qtype='TXT',
+                            qclass=0x8001)
+            )).hex()
+            */
+            val query = HexDump.hexStringToByteArray("0000020000020000000000000d5f6e6d74313233343" +
+                    "536373839045f746370056c6f63616c00000c8001104e7364546573743132333435363738390" +
+                    "d5f6e6d74313233343536373839045f746370056c6f63616c0000108001")
+            replaceServiceNameAndTypeWithTestSuffix(query)
+            /*
+            Send a known answer packet (other service) with truncated bit set.
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=0, aa=0, tc=1, qd=None,
+                    an = scapy.DNSRR(rrname='_test._tcp.local', type='PTR', ttl=4500,
+                            rdata='NsdTest._test._tcp.local')
+            )).hex()
+            */
+            val knownAnswer1 = HexDump.hexStringToByteArray("000002000000000100000000055f74657374" +
+                    "045f746370056c6f63616c00000c000100001194001a074e736454657374055f74657374045f" +
+                    "746370056c6f63616c00")
+            replaceServiceNameAndTypeWithTestSuffix(knownAnswer1)
+            /*
+            Send a known answer packet.
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=0, aa=0, qd=None,
+                    an = scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=4500,
+                            rdata='NsdTest123456789._nmt123456789._tcp.local')
+            )).hex()
+            */
+            val knownAnswer2 = HexDump.hexStringToByteArray("0000000000000001000000000d5f6e6d7431" +
+                    "3233343536373839045f746370056c6f63616c00000c000100001194002b104e736454657374" +
+                    "3132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
+            replaceServiceNameAndTypeWithTestSuffix(knownAnswer2)
+
+            val testSrcAddr = makeLinkLocalAddressOfOtherDeviceOnPrefix(testNetwork1.network)
+            nsResponder = NSResponder(packetReader, mapOf(
+                    testSrcAddr to MacAddress.fromString("01:02:03:04:05:06")
+            )).apply { start() }
+
+            packetReader.sendResponse(buildMdnsPacket(query, testSrcAddr))
+            packetReader.sendResponse(buildMdnsPacket(knownAnswer1, testSrcAddr))
+            packetReader.sendResponse(buildMdnsPacket(knownAnswer2, testSrcAddr))
+            // The reply is sent unicast to the source address. There may be announcements sent
+            // multicast around this time, so filter by destination address.
+            val reply = packetReader.pollForMdnsPacket { pkt ->
+                pkt.isReplyFor("$serviceName.$serviceType.local", DnsResolver.TYPE_TXT) &&
+                        !pkt.isReplyFor("$serviceType.local", DnsResolver.TYPE_PTR) &&
+                        pkt.dstAddr == testSrcAddr
+            }
+            assertNotNull(reply)
+        } cleanup {
+            nsResponder?.stop()
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+        }
+    }
+
     private fun makeLinkLocalAddressOfOtherDeviceOnPrefix(network: Network): Inet6Address {
         val lp = cm.getLinkProperties(network) ?: fail("No LinkProperties for net $network")
         // Expect to have a /64 link-local address
diff --git a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
index 9148770..361d68c 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
@@ -56,7 +56,6 @@
 import com.android.server.NetworkAgentWrapper
 import com.android.server.TestNetIdManager
 import com.android.server.connectivity.CarrierPrivilegeAuthenticator
-import com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener
 import com.android.server.connectivity.ConnectivityResources
 import com.android.server.connectivity.MockableSystemProperties
 import com.android.server.connectivity.MultinetworkPolicyTracker
@@ -89,6 +88,7 @@
 import org.mockito.MockitoAnnotations
 import org.mockito.Spy
 import java.util.function.Consumer
+import java.util.function.BiConsumer
 
 const val SERVICE_BIND_TIMEOUT_MS = 5_000L
 const val TEST_TIMEOUT_MS = 10_000L
@@ -245,7 +245,7 @@
             context: Context,
             tm: TelephonyManager,
             requestRestrictedWifiEnabled: Boolean,
-            listener: CarrierPrivilegesLostListener
+            listener: BiConsumer<Int, Int>
         ): CarrierPrivilegeAuthenticator {
             return CarrierPrivilegeAuthenticator(context,
                 object : CarrierPrivilegeAuthenticator.Dependencies() {
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 6623bbd..c534025 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -173,7 +173,6 @@
 import static com.android.server.ConnectivityServiceTestUtils.transportToLegacyType;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackRegister;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister;
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.testutils.Cleanup.testAndCleanup;
 import static com.android.testutils.ConcurrentUtils.await;
 import static com.android.testutils.ConcurrentUtils.durationOf;
@@ -488,6 +487,7 @@
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
@@ -526,7 +526,7 @@
     // between a LOST callback that arrives immediately and a LOST callback that arrives after
     // the linger/nascent timeout. For this, our assertions should run fast enough to leave
     // less than (mService.mLingerDelayMs - TEST_CALLBACK_TIMEOUT_MS) between the time callbacks are
-    // supposedly fired, and the time we call expectCallback.
+    // supposedly fired, and the time we call expectCallback.
     private static final int TEST_CALLBACK_TIMEOUT_MS = 250;
     // Chosen to be less than TEST_CALLBACK_TIMEOUT_MS. This ensures that requests have time to
     // complete before callbacks are verified.
@@ -565,6 +565,7 @@
     private static final int TEST_PACKAGE_UID2 = 321;
     private static final int TEST_PACKAGE_UID3 = 456;
     private static final int NETWORK_ACTIVITY_NO_UID = -1;
+    private static final int TEST_SUBSCRIPTION_ID = 1;
 
     private static final int PACKET_WAKEUP_MARK_MASK = 0x80000000;
 
@@ -2059,7 +2060,7 @@
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 final boolean requestRestrictedWifiEnabled,
-                CarrierPrivilegesLostListener listener) {
+                BiConsumer<Integer, Integer> listener) {
             return mDeps.isAtLeastT() ? mCarrierPrivilegeAuthenticator : null;
         }
 
@@ -11486,7 +11487,7 @@
         doTestInterfaceClassActivityChanged(TRANSPORT_CELLULAR);
     }
 
-    private void doTestOnNetworkActive_NewNetworkConnects(int transportType, boolean expectCallback)
+    private void doTestOnNetworkActive_NewNetworkConnects(int transportType, boolean expectCallback)
             throws Exception {
         final ConditionVariable onNetworkActiveCv = new ConditionVariable();
         final ConnectivityManager.OnNetworkActiveListener listener = onNetworkActiveCv::open;
@@ -11498,7 +11499,7 @@
         testAndCleanup(() -> {
             mCm.addDefaultNetworkActiveListener(listener);
             agent.connect(true);
-            if (expectCallback) {
+            if (expectCallback) {
                 assertTrue(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
             } else {
                 assertFalse(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
@@ -11513,7 +11514,7 @@
 
     @Test
     public void testOnNetworkActive_NewCellConnects_CallbackCalled() throws Exception {
-        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_CELLULAR, true /* expectCallback */);
+        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_CELLULAR, true /* expectCallback */);
     }
 
     @Test
@@ -11522,8 +11523,8 @@
         // networks that tracker adds the idle timer to. And the tracker does not set the idle timer
         // for the ethernet network.
         // So onNetworkActive is not called when the ethernet becomes the default network
-        final boolean expectCallback = mDeps.isAtLeastV();
-        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_ETHERNET, expectCallback);
+        final boolean expectCallback = mDeps.isAtLeastV();
+        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_ETHERNET, expectCallback);
     }
 
     @Test
@@ -17375,7 +17376,7 @@
         return new NetworkRequest.Builder()
             .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
             .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
-            .setSubscriptionIds(Collections.singleton(Process.myUid()))
+            .setSubscriptionIds(Collections.singleton(TEST_SUBSCRIPTION_ID))
             .build();
     }
 
@@ -17422,32 +17423,46 @@
         final NetworkCallback networkCallback1 = new NetworkCallback();
         final NetworkCallback networkCallback2 = new NetworkCallback();
 
-        mCm.requestNetwork(getRestrictedRequestForWifiWithSubIds(), networkCallback1);
-        mCm.requestNetwork(getRestrictedRequestForWifiWithSubIds(), pendingIntent);
-        mCm.registerNetworkCallback(getRestrictedRequestForWifiWithSubIds(), networkCallback2);
+        mCm.requestNetwork(
+                getRestrictedRequestForWifiWithSubIds(), networkCallback1);
+        mCm.requestNetwork(
+                getRestrictedRequestForWifiWithSubIds(), pendingIntent);
+        mCm.registerNetworkCallback(
+                getRestrictedRequestForWifiWithSubIds(), networkCallback2);
 
         mCm.unregisterNetworkCallback(networkCallback1);
         mCm.releaseNetworkRequest(pendingIntent);
         mCm.unregisterNetworkCallback(networkCallback2);
     }
 
-    @Test
-    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
-    public void testRestrictedRequestRemovedDueToCarrierPrivilegesLost() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
+    private void doTestNetworkRequestWithCarrierPrivilegesLost(
+            boolean shouldGrantRestrictedNetworkPermission,
+            int lostPrivilegeUid,
+            int lostPrivilegeSubId,
+            boolean expectUnavailable,
+            boolean expectCapChanged) throws Exception {
+        if (shouldGrantRestrictedNetworkPermission) {
+            mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_GRANTED);
+        } else {
+            mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
+        }
+
+        NetworkCapabilities filter =
+                getRestrictedRequestForWifiWithSubIds().networkCapabilities;
         final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
         handlerThread.start();
+
         final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
                 mServiceContext, "testFactory", filter, mCsHandlerThread);
         testFactory.register();
-
         testFactory.assertRequestCountEquals(0);
+
         doReturn(true).when(mCarrierPrivilegeAuthenticator)
                 .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
+        final TestNetworkCallback networkCallback = new TestNetworkCallback();
+        final NetworkRequest networkrequest =
+                getRestrictedRequestForWifiWithSubIds();
+        mCm.requestNetwork(networkrequest, networkCallback);
         testFactory.expectRequestAdd();
         testFactory.assertRequestCountEquals(1);
 
@@ -17455,24 +17470,36 @@
                 .setAllowedUids(Set.of(Process.myUid()))
                 .build();
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, new LinkProperties(), nc);
-        mWiFiAgent.connect(true);
-        networkCallback1.expectAvailableThenValidatedCallbacks(mWiFiAgent);
-
+        mWiFiAgent.connect(false);
+        networkCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         final NetworkAgentInfo nai = mService.getNetworkAgentInfoForNetwork(
                 mWiFiAgent.getNetwork());
 
         doReturn(false).when(mCarrierPrivilegeAuthenticator)
                 .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid());
+        doReturn(TEST_SUBSCRIPTION_ID).when(mCarrierPrivilegeAuthenticator)
+                .getSubIdFromNetworkCapabilities(any());
+        mService.onCarrierPrivilegesLost(lostPrivilegeUid, lostPrivilegeSubId);
         waitForIdle();
 
-        testFactory.expectRequestRemove();
-        testFactory.assertRequestCountEquals(0);
-        assertTrue(nai.networkCapabilities.getAllowedUidsNoCopy().isEmpty());
-        networkCallback1.expect(NETWORK_CAPS_UPDATED);
-        networkCallback1.expect(UNAVAILABLE);
+        if (expectCapChanged) {
+            networkCallback.expect(NETWORK_CAPS_UPDATED);
+        }
+        if (expectUnavailable) {
+            networkCallback.expect(UNAVAILABLE);
+        }
+        if (!expectCapChanged && !expectUnavailable) {
+            networkCallback.assertNoCallback();
+        }
+
+        mWiFiAgent.disconnect();
+        waitForIdle();
+
+        if (expectUnavailable) {
+            testFactory.assertRequestCountEquals(0);
+        } else {
+            testFactory.assertRequestCountEquals(1);
+        }
 
         handlerThread.quitSafely();
         handlerThread.join();
@@ -17480,64 +17507,45 @@
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+    public void testRestrictedRequestRemovedDueToCarrierPrivilegesLost() throws Exception {
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid(),
+                TEST_SUBSCRIPTION_ID,
+                true /* expectUnavailable */,
+                true /* expectCapChanged */);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+    public void testRequestNotRemoved_MismatchSubId() throws Exception {
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid(),
+                TEST_SUBSCRIPTION_ID + 1,
+                false /* expectUnavailable */,
+                false /* expectCapChanged */);
+    }
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testRequestNotRemoved_MismatchUid() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
-        final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
-        handlerThread.start();
-
-        final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
-                mServiceContext, "testFactory", filter, mCsHandlerThread);
-        testFactory.register();
-
-        doReturn(true).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(anyInt(), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
-        testFactory.expectRequestAdd();
-        testFactory.assertRequestCountEquals(1);
-
-        doReturn(false).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid() + 1);
-        expectNoRequestChanged(testFactory);
-
-        handlerThread.quitSafely();
-        handlerThread.join();
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid() + 1,
+                TEST_SUBSCRIPTION_ID,
+                false /* expectUnavailable */,
+                false /* expectCapChanged */);
     }
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testRequestNotRemoved_HasRestrictedNetworkPermission() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_GRANTED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
-        final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
-        handlerThread.start();
-
-        final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
-                mServiceContext, "testFactory", filter, mCsHandlerThread);
-        testFactory.register();
-
-        doReturn(true).when(mCarrierPrivilegeAuthenticator)
-            .isCarrierServiceUidForNetworkCapabilities(anyInt(), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
-        testFactory.expectRequestAdd();
-        testFactory.assertRequestCountEquals(1);
-
-        doReturn(false).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid());
-        expectNoRequestChanged(testFactory);
-
-        handlerThread.quitSafely();
-        handlerThread.join();
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                true /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid(),
+                TEST_SUBSCRIPTION_ID,
+                false /* expectUnavailable */,
+                true /* expectCapChanged */);
     }
     @Test
     public void testAllowedUids() throws Exception {
diff --git a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
index 9f0ec30..7bd2b56 100644
--- a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
@@ -20,7 +20,6 @@
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.telephony.TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED;
 
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.server.connectivity.ConnectivityFlags.CARRIER_SERVICE_CHANGED_USE_CALLBACK;
 
 import static org.junit.Assert.assertEquals;
@@ -47,7 +46,6 @@
 import android.net.TelephonyNetworkSpecifier;
 import android.os.Build;
 import android.os.HandlerThread;
-import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 
 import com.android.net.module.util.CollectionUtils;
@@ -71,6 +69,7 @@
 import java.util.Collections;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiConsumer;
 
 /**
  * Tests for CarrierPrivilegeAuthenticatorTest.
@@ -92,7 +91,7 @@
     @NonNull private final TelephonyManagerShimImpl mTelephonyManagerShim;
     @NonNull private final PackageManager mPackageManager;
     @NonNull private TestCarrierPrivilegeAuthenticator mCarrierPrivilegeAuthenticator;
-    @NonNull private final CarrierPrivilegesLostListener mListener;
+    @NonNull private final BiConsumer<Integer, Integer> mListener;
     private final int mCarrierConfigPkgUid = 12345;
     private final boolean mUseCallbacks;
     private final String mTestPkg = "com.android.server.connectivity.test";
@@ -107,9 +106,8 @@
                     mListener);
         }
         @Override
-        protected int getSlotIndex(int subId) {
-            if (SubscriptionManager.DEFAULT_SUBSCRIPTION_ID == subId) return TEST_SUBSCRIPTION_ID;
-            return subId;
+        protected int getSubId(int slotIndex) {
+            return TEST_SUBSCRIPTION_ID;
         }
     }
 
@@ -129,7 +127,7 @@
         mTelephonyManager = mock(TelephonyManager.class);
         mTelephonyManagerShim = mock(TelephonyManagerShimImpl.class);
         mPackageManager = mock(PackageManager.class);
-        mListener = mock(CarrierPrivilegesLostListener.class);
+        mListener = mock(BiConsumer.class);
         mHandlerThread = new HandlerThread(CarrierPrivilegeAuthenticatorTest.class.getSimpleName());
         mUseCallbacks = useCallbacks;
         final Dependencies deps = mock(Dependencies.class);
@@ -184,7 +182,7 @@
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
-                .setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+                .setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
 
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
@@ -220,7 +218,8 @@
 
         newListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
 
-        final TelephonyNetworkSpecifier specifier = new TelephonyNetworkSpecifier(0);
+        final TelephonyNetworkSpecifier specifier =
+                new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID);
         final NetworkCapabilities nc = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
                 .setNetworkSpecifier(specifier)
@@ -239,7 +238,11 @@
         l.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
         l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 1);
         if (mUseCallbacks) {
-            verify(mListener).onCarrierPrivilegesLost(eq(mCarrierConfigPkgUid));
+            verify(mListener).accept(eq(mCarrierConfigPkgUid), eq(TEST_SUBSCRIPTION_ID));
+        }
+        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 2);
+        if (mUseCallbacks) {
+            verify(mListener).accept(eq(mCarrierConfigPkgUid + 1), eq(TEST_SUBSCRIPTION_ID));
         }
     }
 
@@ -247,7 +250,8 @@
     public void testOnCarrierPrivilegesChanged() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
 
-        final TelephonyNetworkSpecifier specifier = new TelephonyNetworkSpecifier(0);
+        final TelephonyNetworkSpecifier specifier =
+                new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID);
         final NetworkCapabilities nc = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
                 .setNetworkSpecifier(specifier)
@@ -275,7 +279,7 @@
         assertFalse(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
 
-        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
 
@@ -284,7 +288,7 @@
         ncBuilder.setNetworkSpecifier(null);
         ncBuilder.removeTransportType(TRANSPORT_CELLULAR);
         ncBuilder.addTransportType(TRANSPORT_WIFI);
-        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
         assertFalse(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
     }
@@ -298,7 +302,7 @@
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_WIFI);
         ncBuilder.removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED);
-        ncBuilder.setSubscriptionIds(Set.of(0));
+        ncBuilder.setSubscriptionIds(Set.of(TEST_SUBSCRIPTION_ID));
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
     }
diff --git a/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt b/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
index 64a515a..193078b 100644
--- a/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
@@ -21,9 +21,12 @@
 import android.content.Context
 import android.content.pm.ApplicationInfo
 import android.content.pm.PackageManager
+import android.content.pm.UserInfo
 import android.os.Build
 import android.os.Handler
 import android.os.UserHandle
+import android.util.ArraySet
+import com.android.server.makeMockUserManager
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
 import org.junit.Before
@@ -36,18 +39,31 @@
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.never
-import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
 import java.util.concurrent.Executor
 import java.util.function.Consumer
-import kotlin.test.assertEquals
-import kotlin.test.assertFalse
-import kotlin.test.assertTrue
 
-private const val DEFAULT_MESSAGING_APP1 = "default_messaging_app_1"
-private const val DEFAULT_MESSAGING_APP2 = "default_messaging_app_2"
-private const val DEFAULT_MESSAGING_APP1_UID = 1234
-private const val DEFAULT_MESSAGING_APP2_UID = 5678
+private const val USER = 0
+val USER_INFO = UserInfo(USER, "" /* name */, UserInfo.FLAG_PRIMARY)
+val USER_HANDLE = UserHandle(USER)
+private const val PRIMARY_USER = 0
+private const val SECONDARY_USER = 10
+private val PRIMARY_USER_HANDLE = UserHandle.of(PRIMARY_USER)
+private val SECONDARY_USER_HANDLE = UserHandle.of(SECONDARY_USER)
+// sms app names
+private const val SMS_APP1 = "sms_app_1"
+private const val SMS_APP2 = "sms_app_2"
+// sms app ids
+private const val SMS_APP_ID1 = 100
+private const val SMS_APP_ID2 = 101
+// UIDs of app1 and app2 on the primary user.
+// Either app could become the default SMS app for the primary user.
+private val PRIMARY_USER_SMS_APP_UID1 = UserHandle.getUid(PRIMARY_USER, SMS_APP_ID1)
+private val PRIMARY_USER_SMS_APP_UID2 = UserHandle.getUid(PRIMARY_USER, SMS_APP_ID2)
+// UIDs of app1 and app2 on the secondary user.
+// Either app could become the default SMS app for the secondary user.
+private val SECONDARY_USER_SMS_APP_UID1 = UserHandle.getUid(SECONDARY_USER, SMS_APP_ID1)
+private val SECONDARY_USER_SMS_APP_UID2 = UserHandle.getUid(SECONDARY_USER, SMS_APP_ID2)
 
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
@@ -58,33 +74,36 @@
     private val mRoleManager =
         mock(SatelliteAccessController.Dependencies::class.java)
     private val mCallback = mock(Consumer::class.java) as Consumer<Set<Int>>
-    private val mSatelliteAccessController by lazy {
-        SatelliteAccessController(context, mRoleManager, mCallback, mHandler)}
-    private var mRoleHolderChangedListener: OnRoleHoldersChangedListener? = null
+    private val mSatelliteAccessController =
+        SatelliteAccessController(context, mRoleManager, mCallback, mHandler)
+    private lateinit var mRoleHolderChangedListener: OnRoleHoldersChangedListener
     @Before
     @Throws(PackageManager.NameNotFoundException::class)
     fun setup() {
+        makeMockUserManager(USER_INFO, USER_HANDLE)
+        doReturn(context).`when`(context).createContextAsUser(any(), anyInt())
         doReturn(mPackageManager).`when`(context).packageManager
-        doReturn(PackageManager.PERMISSION_GRANTED)
-            .`when`(mPackageManager)
-            .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, DEFAULT_MESSAGING_APP1)
-        doReturn(PackageManager.PERMISSION_GRANTED)
-            .`when`(mPackageManager)
-            .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, DEFAULT_MESSAGING_APP2)
 
-        // Initialise default message application package1
+        doReturn(PackageManager.PERMISSION_GRANTED)
+            .`when`(mPackageManager)
+            .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP1)
+        doReturn(PackageManager.PERMISSION_GRANTED)
+            .`when`(mPackageManager)
+            .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP2)
+
+        // Initialise default message application primary user package1
         val applicationInfo1 = ApplicationInfo()
-        applicationInfo1.uid = DEFAULT_MESSAGING_APP1_UID
+        applicationInfo1.uid = PRIMARY_USER_SMS_APP_UID1
         doReturn(applicationInfo1)
             .`when`(mPackageManager)
-            .getApplicationInfo(eq(DEFAULT_MESSAGING_APP1), anyInt())
+            .getApplicationInfo(eq(SMS_APP1), anyInt())
 
-        // Initialise default message application package2
+        // Initialise default message application primary user package2
         val applicationInfo2 = ApplicationInfo()
-        applicationInfo2.uid = DEFAULT_MESSAGING_APP2_UID
+        applicationInfo2.uid = PRIMARY_USER_SMS_APP_UID2
         doReturn(applicationInfo2)
             .`when`(mPackageManager)
-            .getApplicationInfo(eq(DEFAULT_MESSAGING_APP2), anyInt())
+            .getApplicationInfo(eq(SMS_APP2), anyInt())
 
         // Get registered listener using captor
         val listenerCaptor = ArgumentCaptor.forClass(
@@ -97,80 +116,107 @@
     }
 
     @Test
-    fun test_onRoleHoldersChanged_SatellitePreferredUid_Changed() {
-        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHolders(RoleManager.ROLE_SMS)
-        val satelliteNetworkPreferredSet =
-            ArgumentCaptor.forClass(Set::class.java) as ArgumentCaptor<Set<Int>>
-        mRoleHolderChangedListener?.onRoleHoldersChanged(RoleManager.ROLE_SMS, UserHandle.ALL)
-        verify(mCallback, never()).accept(satelliteNetworkPreferredSet.capture())
+    fun test_onRoleHoldersChanged_SatelliteFallbackUid_Changed_SingleUser() {
+        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback, never()).accept(any())
 
-        // check DEFAULT_MESSAGING_APP1 is available as satellite network preferred uid
-        doReturn(listOf(DEFAULT_MESSAGING_APP1))
-            .`when`(mRoleManager).getRoleHolders(RoleManager.ROLE_SMS)
-        mRoleHolderChangedListener?.onRoleHoldersChanged(RoleManager.ROLE_SMS, UserHandle.ALL)
-        verify(mCallback).accept(satelliteNetworkPreferredSet.capture())
-        var satelliteNetworkPreferredUids = satelliteNetworkPreferredSet.value
-        assertEquals(1, satelliteNetworkPreferredUids.size)
-        assertTrue(satelliteNetworkPreferredUids.contains(DEFAULT_MESSAGING_APP1_UID))
-        assertFalse(satelliteNetworkPreferredUids.contains(DEFAULT_MESSAGING_APP2_UID))
+        // check SMS_APP1 is available as satellite network fallback uid
+        doReturn(listOf(SMS_APP1))
+            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID1))
 
-        // check DEFAULT_MESSAGING_APP1 and DEFAULT_MESSAGING_APP2 is available
-        // as satellite network preferred uid
-        val dmas: MutableList<String> = ArrayList()
-        dmas.add(DEFAULT_MESSAGING_APP1)
-        dmas.add(DEFAULT_MESSAGING_APP2)
-        doReturn(dmas).`when`(mRoleManager).getRoleHolders(RoleManager.ROLE_SMS)
-        mRoleHolderChangedListener?.onRoleHoldersChanged(RoleManager.ROLE_SMS, UserHandle.ALL)
-        verify(mCallback, times(2))
-            .accept(satelliteNetworkPreferredSet.capture())
-        satelliteNetworkPreferredUids = satelliteNetworkPreferredSet.value
-        assertEquals(2, satelliteNetworkPreferredUids.size)
-        assertTrue(satelliteNetworkPreferredUids.contains(DEFAULT_MESSAGING_APP1_UID))
-        assertTrue(satelliteNetworkPreferredUids.contains(DEFAULT_MESSAGING_APP2_UID))
+        // check SMS_APP2 is available as satellite network fallback uid
+        doReturn(listOf(SMS_APP2)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
 
-        // check no uid is available as satellite network preferred uid
-        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHolders(RoleManager.ROLE_SMS)
-        mRoleHolderChangedListener?.onRoleHoldersChanged(RoleManager.ROLE_SMS, UserHandle.ALL)
-        verify(mCallback, times(3))
-            .accept(satelliteNetworkPreferredSet.capture())
-        satelliteNetworkPreferredUids = satelliteNetworkPreferredSet.value
-        assertEquals(0, satelliteNetworkPreferredUids.size)
-        assertFalse(satelliteNetworkPreferredUids.contains(DEFAULT_MESSAGING_APP1_UID))
-        assertFalse(satelliteNetworkPreferredUids.contains(DEFAULT_MESSAGING_APP2_UID))
-
-        // No Change received at OnRoleSmsChanged, check callback not triggered
-        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHolders(RoleManager.ROLE_SMS)
-        mRoleHolderChangedListener?.onRoleHoldersChanged(RoleManager.ROLE_SMS, UserHandle.ALL)
-        verify(mCallback, times(3))
-            .accept(satelliteNetworkPreferredSet.capture())
+        // check no uid is available as satellite network fallback uid
+        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback).accept(ArraySet())
     }
 
     @Test
     fun test_onRoleHoldersChanged_NoSatelliteCommunicationPermission() {
-        doReturn(listOf<Any>()).`when`(mRoleManager).getRoleHolders(RoleManager.ROLE_SMS)
-        val satelliteNetworkPreferredSet =
-            ArgumentCaptor.forClass(Set::class.java) as ArgumentCaptor<Set<Int>>
-        mRoleHolderChangedListener?.onRoleHoldersChanged(RoleManager.ROLE_SMS, UserHandle.ALL)
-        verify(mCallback, never()).accept(satelliteNetworkPreferredSet.capture())
+        doReturn(listOf<Any>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback, never()).accept(any())
 
-        // check DEFAULT_MESSAGING_APP1 is not available as satellite network preferred uid
+        // check SMS_APP1 is not available as satellite network fallback uid
         // since satellite communication permission not available.
         doReturn(PackageManager.PERMISSION_DENIED)
             .`when`(mPackageManager)
-            .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, DEFAULT_MESSAGING_APP1)
-        doReturn(listOf(DEFAULT_MESSAGING_APP1))
-            .`when`(mRoleManager).getRoleHolders(RoleManager.ROLE_SMS)
-        mRoleHolderChangedListener?.onRoleHoldersChanged(RoleManager.ROLE_SMS, UserHandle.ALL)
-        verify(mCallback, never()).accept(satelliteNetworkPreferredSet.capture())
+            .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP1)
+        doReturn(listOf(SMS_APP1))
+            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback, never()).accept(any())
     }
 
     @Test
     fun test_onRoleHoldersChanged_RoleSms_NotAvailable() {
-        doReturn(listOf(DEFAULT_MESSAGING_APP1))
-            .`when`(mRoleManager).getRoleHolders(RoleManager.ROLE_SMS)
-        val satelliteNetworkPreferredSet =
-            ArgumentCaptor.forClass(Set::class.java) as ArgumentCaptor<Set<Int>>
-        mRoleHolderChangedListener?.onRoleHoldersChanged(RoleManager.ROLE_BROWSER, UserHandle.ALL)
-        verify(mCallback, never()).accept(satelliteNetworkPreferredSet.capture())
+        doReturn(listOf(SMS_APP1))
+            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_BROWSER,
+            PRIMARY_USER_HANDLE)
+        verify(mCallback, never()).accept(any())
+    }
+
+    @Test
+    fun test_onRoleHoldersChanged_SatelliteNetworkFallbackUid_Changed_multiUser() {
+        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback, never()).accept(any())
+
+        // check SMS_APP1 is available as satellite network fallback uid at primary user
+        doReturn(listOf(SMS_APP1))
+            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID1))
+
+        // check SMS_APP2 is available as satellite network fallback uid at primary user
+        doReturn(listOf(SMS_APP2)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
+
+        // check SMS_APP1 is available as satellite network fallback uid at secondary user
+        val applicationInfo1 = ApplicationInfo()
+        applicationInfo1.uid = SECONDARY_USER_SMS_APP_UID1
+        doReturn(applicationInfo1).`when`(mPackageManager)
+            .getApplicationInfo(eq(SMS_APP1), anyInt())
+        doReturn(listOf(SMS_APP1)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            SECONDARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2, SECONDARY_USER_SMS_APP_UID1))
+
+        // check no uid is available as satellite network fallback uid at primary user
+        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(SECONDARY_USER_SMS_APP_UID1))
+
+        // check SMS_APP2 is available as satellite network fallback uid at secondary user
+        applicationInfo1.uid = SECONDARY_USER_SMS_APP_UID2
+        doReturn(applicationInfo1).`when`(mPackageManager)
+            .getApplicationInfo(eq(SMS_APP2), anyInt())
+        doReturn(listOf(SMS_APP2))
+            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(SECONDARY_USER_SMS_APP_UID2))
+
+        // check no uid is available as satellite network fallback uid at secondary user
+        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
+            SECONDARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
+        verify(mCallback).accept(ArraySet())
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSCaptivePortalAppTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSCaptivePortalAppTest.kt
new file mode 100644
index 0000000..be2b29c
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivityservice/CSCaptivePortalAppTest.kt
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server
+
+import android.Manifest.permission.NETWORK_STACK
+import android.content.Intent
+import android.content.pm.PackageManager.PERMISSION_DENIED
+import android.content.pm.PackageManager.PERMISSION_GRANTED
+import android.net.ConnectivityManager.ACTION_CAPTIVE_PORTAL_SIGN_IN
+import android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL
+import android.net.IpPrefix
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_CAPTIVE_PORTAL
+import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED
+import android.net.NetworkCapabilities.TRANSPORT_WIFI
+import android.net.NetworkStack
+import android.net.CaptivePortal
+import android.net.NetworkRequest
+import android.net.NetworkScore
+import android.net.NetworkScore.KEEP_CONNECTED_FOR_TEST
+import android.net.RouteInfo
+import android.os.Build
+import android.os.Bundle
+import androidx.test.filters.SmallTest
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.assertThrows
+import com.android.testutils.TestableNetworkCallback
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito.verify
+import kotlin.test.assertEquals
+
+// This allows keeping all the networks connected without having to file individual requests
+// for them.
+private fun keepScore() = FromS(
+        NetworkScore.Builder().setKeepConnectedReason(KEEP_CONNECTED_FOR_TEST).build()
+)
+
+private fun nc(transport: Int, vararg caps: Int) = NetworkCapabilities.Builder().apply {
+    addTransportType(transport)
+    caps.forEach {
+        addCapability(it)
+    }
+    // Useful capabilities for everybody
+    addCapability(NET_CAPABILITY_NOT_RESTRICTED)
+    addCapability(NET_CAPABILITY_NOT_SUSPENDED)
+    addCapability(NET_CAPABILITY_NOT_ROAMING)
+    addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
+}.build()
+
+private fun lp(iface: String) = LinkProperties().apply {
+    interfaceName = iface
+    addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 32))
+    addRoute(RouteInfo(IpPrefix("0.0.0.0/0"), null, null))
+}
+
+@DevSdkIgnoreRunner.MonitorThreadLeak
+@RunWith(DevSdkIgnoreRunner::class)
+@SmallTest
+@IgnoreUpTo(Build.VERSION_CODES.R)
+class CSCaptivePortalAppTest : CSTest() {
+    private val WIFI_IFACE = "wifi0"
+    private val TEST_REDIRECT_URL = "http://example.com/firstPath"
+    private val TIMEOUT_MS = 2_000L
+
+    @Test
+    fun testCaptivePortalApp_Reevaluate_Nopermission() {
+        val captivePortalCallback = TestableNetworkCallback()
+        val captivePortalRequest = NetworkRequest.Builder()
+                .addCapability(NET_CAPABILITY_CAPTIVE_PORTAL).build()
+        cm.registerNetworkCallback(captivePortalRequest, captivePortalCallback)
+        val wifiAgent = createWifiAgent()
+        wifiAgent.connectWithCaptivePortal(TEST_REDIRECT_URL)
+        captivePortalCallback.expectAvailableCallbacksUnvalidated(wifiAgent)
+        val signInIntent = startCaptivePortalApp(wifiAgent)
+        // Revoke the network stack permissions so that reevaluateNetwork() must be rejected
+        context.setPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK,
+                PERMISSION_DENIED)
+        context.setPermission(NETWORK_STACK, PERMISSION_DENIED)
+        val captivePortal: CaptivePortal? = signInIntent.getParcelableExtra(EXTRA_CAPTIVE_PORTAL)
+        assertThrows(SecurityException::class.java, { captivePortal?.reevaluateNetwork() })
+    }
+
+    private fun createWifiAgent(): CSAgentWrapper {
+        return Agent(score = keepScore(), lp = lp(WIFI_IFACE),
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET))
+    }
+
+    private fun startCaptivePortalApp(networkAgent: CSAgentWrapper): Intent {
+        val network = networkAgent.network
+        cm.startCaptivePortalApp(network)
+        waitForIdle()
+        verify(networkAgent.networkMonitor).launchCaptivePortalApp()
+
+        val testBundle = Bundle()
+        val testKey = "testkey"
+        val testValue = "testvalue"
+        testBundle.putString(testKey, testValue)
+        context.setPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, PERMISSION_GRANTED)
+        cm.startCaptivePortalApp(network, testBundle)
+        val signInIntent: Intent = context.expectStartActivityIntent(TIMEOUT_MS)
+        assertEquals(ACTION_CAPTIVE_PORTAL_SIGN_IN, signInIntent.getAction())
+        assertEquals(testValue, signInIntent.getStringExtra(testKey))
+        return signInIntent
+    }
+}
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
index d41c742..d7343b1 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
@@ -19,6 +19,8 @@
 import android.content.Context
 import android.net.ConnectivityManager
 import android.net.INetworkMonitor
+import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_DNS
+import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_HTTP
 import android.net.INetworkMonitorCallbacks
 import android.net.LinkProperties
 import android.net.LocalNetworkConfig
@@ -75,10 +77,15 @@
 ) : TestableNetworkCallback.HasNetwork {
     private val TAG = "CSAgent${nextAgentId()}"
     private val VALIDATION_RESULT_INVALID = 0
+    private val NO_PROBE_RESULT = 0
     private val VALIDATION_TIMESTAMP = 1234L
     private val agent: NetworkAgent
     private val nmCallbacks: INetworkMonitorCallbacks
     val networkMonitor = mock<INetworkMonitor>()
+    private var nmValidationRedirectUrl: String? = null
+    private var nmValidationResult = NO_PROBE_RESULT
+    private var nmProbesCompleted = NO_PROBE_RESULT
+    private var nmProbesSucceeded = NO_PROBE_RESULT
 
     override val network: Network get() = agent.network!!
 
@@ -120,10 +127,10 @@
         }
         nmCallbacks.notifyProbeStatusChanged(0 /* completed */, 0 /* succeeded */)
         val p = NetworkTestResultParcelable()
-        p.result = VALIDATION_RESULT_INVALID
-        p.probesAttempted = 0
-        p.probesSucceeded = 0
-        p.redirectUrl = null
+        p.result = nmValidationResult
+        p.probesAttempted = nmProbesCompleted
+        p.probesSucceeded = nmProbesSucceeded
+        p.redirectUrl = nmValidationRedirectUrl
         p.timestampMillis = VALIDATION_TIMESTAMP
         nmCallbacks.notifyNetworkTestedWithExtras(p)
     }
@@ -171,4 +178,26 @@
 
     fun sendLocalNetworkConfig(lnc: LocalNetworkConfig) = agent.sendLocalNetworkConfig(lnc)
     fun sendNetworkCapabilities(nc: NetworkCapabilities) = agent.sendNetworkCapabilities(nc)
+
+    fun connectWithCaptivePortal(redirectUrl: String) {
+        setCaptivePortal(redirectUrl)
+        connect()
+    }
+
+    fun setProbesStatus(probesCompleted: Int, probesSucceeded: Int) {
+        nmProbesCompleted = probesCompleted
+        nmProbesSucceeded = probesSucceeded
+    }
+
+    fun setCaptivePortal(redirectUrl: String) {
+        nmValidationResult = VALIDATION_RESULT_INVALID
+        nmValidationRedirectUrl = redirectUrl
+        // Assume the portal is detected by the initial NETWORK_VALIDATION_PROBE_HTTP probe.
+        // NETWORK_VALIDATION_PROBE_HTTP is the decisive probe for captive portal detection, so
+        // NETWORK_VALIDATION_PROBE_HTTPS has not run yet: report only the DNS and HTTP probes
+        // as completed, and none as succeeded.
+        setProbesStatus(
+            NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTP /* probesCompleted */,
+            VALIDATION_RESULT_INVALID /* probesSucceeded */)
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index b15c684..595ca47 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -47,8 +47,10 @@
 import android.os.Bundle
 import android.os.Handler
 import android.os.HandlerThread
+import android.os.Process
 import android.os.UserHandle
 import android.os.UserManager
+import android.permission.PermissionManager.PermissionResult
 import android.telephony.TelephonyManager
 import android.testing.TestableContext
 import android.util.ArraySet
@@ -60,7 +62,6 @@
 import com.android.networkstack.apishim.common.UnsupportedApiLevelException
 import com.android.server.connectivity.AutomaticOnOffKeepaliveTracker
 import com.android.server.connectivity.CarrierPrivilegeAuthenticator
-import com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener
 import com.android.server.connectivity.ClatCoordinator
 import com.android.server.connectivity.ConnectivityFlags
 import com.android.server.connectivity.MulticastRoutingCoordinatorService
@@ -72,7 +73,11 @@
 import com.android.testutils.visibleOnHandlerThread
 import com.android.testutils.waitForIdle
 import java.util.concurrent.Executors
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.TimeUnit
 import java.util.function.Consumer
+import java.util.function.BiConsumer
+import kotlin.test.assertNotNull
 import kotlin.test.assertNull
 import kotlin.test.fail
 import org.junit.After
@@ -82,7 +87,7 @@
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 
-internal const val HANDLER_TIMEOUT_MS = 2_000
+internal const val HANDLER_TIMEOUT_MS = 2_000L
 internal const val BROADCAST_TIMEOUT_MS = 3_000L
 internal const val TEST_PACKAGE_NAME = "com.android.test.package"
 internal const val WIFI_WOL_IFNAME = "test_wlan_wol"
@@ -222,7 +227,7 @@
                 context: Context,
                 tm: TelephonyManager,
                 requestRestrictedWifiEnabled: Boolean,
-                listener: CarrierPrivilegesLostListener
+                listener: BiConsumer<Int, Int>
         ) = if (SdkLevel.isAtLeastT()) mock<CarrierPrivilegeAuthenticator>() else null
 
         var satelliteNetworkFallbackUidUpdate: Consumer<Set<Int>>? = null
@@ -300,13 +305,65 @@
         val pacProxyManager = mock<PacProxyManager>()
         val networkPolicyManager = mock<NetworkPolicyManager>()
 
+        // Map of permission name -> PackageManager.PERMISSION_{GRANTED|DENIED} constant.
+        // For permissions granted across the board, the key is only the permission name.
+        // For permissions only granted to a combination of uid/pid, the key
+        // is "<permission name>,<pid>,<uid>". PID+UID permissions have priority over generic ones.
+        private val mMockedPermissions: HashMap<String, Int> = HashMap()
+        private val mStartedActivities = LinkedBlockingQueue<Intent>()
         override fun getPackageManager() = this@CSTest.packageManager
         override fun getContentResolver() = this@CSTest.contentResolver
 
-        // TODO : buff up the capabilities of this permission scheme to allow checking for
-        // permission rejections
-        override fun checkPermission(permission: String, pid: Int, uid: Int) = PERMISSION_GRANTED
-        override fun checkCallingOrSelfPermission(permission: String) = PERMISSION_GRANTED
+        // If no result is set in mMockedPermissions, the permission is treated as
+        // PERMISSION_GRANTED, preserving the previous behavior so other tests do not break.
+        override fun checkPermission(permission: String, pid: Int, uid: Int) =
+            checkMockedPermission(permission, pid, uid, PERMISSION_GRANTED)
+
+        override fun enforceCallingOrSelfPermission(permission: String, message: String?) {
+            // If no result is set in mMockedPermissions, the permission is treated as
+            // PERMISSION_GRANTED, preserving the previous behavior so other tests do not break.
+            val granted = checkMockedPermission(permission, Process.myPid(), Process.myUid(),
+                PERMISSION_GRANTED)
+            if (!granted.equals(PERMISSION_GRANTED)) {
+                throw SecurityException("[Test] permission denied: " + permission)
+            }
+        }
+
+        // If no result is set in mMockedPermissions, the permission is treated as
+        // PERMISSION_GRANTED, preserving the previous behavior so other tests do not break.
+        override fun checkCallingOrSelfPermission(permission: String) =
+            checkMockedPermission(permission, Process.myPid(), Process.myUid(), PERMISSION_GRANTED)
+
+        private fun checkMockedPermission(permission: String, pid: Int, uid: Int, default: Int):
+                Int {
+            val processSpecificKey = "$permission,$pid,$uid"
+            return mMockedPermissions[processSpecificKey]
+                    ?: mMockedPermissions[permission] ?: default
+        }
+
+        /**
+         * Makes checks for [permission] return `granted`, i.e. grant or deny it.
+         * A result set for a specific PID/UID with the other overload takes priority.
+         *
+         * This applies to all calls, regardless of the checked UID and PID.
+         *
+         * @param granted PERMISSION_GRANTED or PERMISSION_DENIED from [PackageManager].
+         */
+        fun setPermission(permission: String, @PermissionResult granted: Int) {
+            mMockedPermissions.put(permission, granted)
+        }
+
+        /**
+         * Makes checks for [permission] by the given [pid]/[uid] return `granted`.
+         * This takes priority over a result set for the permission without a PID/UID.
+         *
+         * This only applies to checks for exactly the passed UID and PID.
+         *
+         * @param granted PERMISSION_GRANTED or PERMISSION_DENIED from [PackageManager].
+         */
+        fun setPermission(permission: String, pid: Int, uid: Int, @PermissionResult granted: Int) {
+            mMockedPermissions.put("$permission,$pid,$uid", granted)
+        }
 
         // Necessary for MultinetworkPolicyTracker, which tries to register a receiver for
         // all users. The test can't do that since it doesn't hold INTERACT_ACROSS_USERS.
@@ -364,6 +421,16 @@
         ) {
             orderedBroadcastAsUserHistory.add(intent)
         }
+
+        override fun startActivityAsUser(intent: Intent, handle: UserHandle) {
+            mStartedActivities.put(intent) // Only recorded; see expectStartActivityIntent()
+        }
+
+        fun expectStartActivityIntent(timeoutMs: Long = HANDLER_TIMEOUT_MS): Intent {
+            // Waits up to timeoutMs for startActivityAsUser() to record an intent.
+            val started = mStartedActivities.poll(timeoutMs, TimeUnit.MILLISECONDS)
+            return assertNotNull(started, "Did not receive sign-in intent after ${timeoutMs}ms")
+        }
     }
 
     // Utility methods for subclasses to use
diff --git a/tests/unit/java/com/android/server/net/TrafficStatsRateLimitCacheTest.kt b/tests/unit/java/com/android/server/net/TrafficStatsRateLimitCacheTest.kt
new file mode 100644
index 0000000..27e6f96
--- /dev/null
+++ b/tests/unit/java/com/android/server/net/TrafficStatsRateLimitCacheTest.kt
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.net
+
+import android.net.NetworkStats
+import com.android.testutils.DevSdkIgnoreRunner
+import java.time.Clock
+import kotlin.test.assertEquals
+import kotlin.test.assertNull
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.`when`
+
+@RunWith(DevSdkIgnoreRunner::class)
+class TrafficStatsRateLimitCacheTest {
+    companion object {
+        private const val EXPIRY_DURATION_MS = 1000L
+    }
+
+    private val clock = mock(Clock::class.java)
+    private val entry = mock(NetworkStats.Entry::class.java)
+    private val cache = TrafficStatsRateLimitCache(clock, EXPIRY_DURATION_MS)
+
+    @Test
+    fun testGet_returnsEntryIfNotExpired() {
+        `when`(clock.millis()).thenReturn(0L) // Fix the clock at 0 while putting the entry
+        cache.put("iface", 2, entry)
+        `when`(clock.millis()).thenReturn(500L) // Set clock to before expiry
+        assertEquals(entry, cache.get("iface", 2))
+    }
+
+    @Test
+    fun testGet_returnsNullIfExpired() {
+        `when`(clock.millis()).thenReturn(0L) // Fix the clock at 0 while putting the entry
+        cache.put("iface", 2, entry)
+        `when`(clock.millis()).thenReturn(2000L) // Set clock to after expiry
+        assertNull(cache.get("iface", 2))
+    }
+
+    @Test
+    fun testGet_returnsNullForNonExistentKey() {
+        assertNull(cache.get("otherIface", 99))
+    }
+
+    @Test
+    fun testPutAndGet_retrievesCorrectEntryForDifferentKeys() {
+        val entry1 = mock(NetworkStats.Entry::class.java)
+        val entry2 = mock(NetworkStats.Entry::class.java)
+
+        cache.put("iface1", 2, entry1)
+        cache.put("iface2", 4, entry2)
+
+        assertEquals(entry1, cache.get("iface1", 2))
+        assertEquals(entry2, cache.get("iface2", 4))
+    }
+
+    @Test
+    fun testPut_overridesExistingEntry() {
+        val entry1 = mock(NetworkStats.Entry::class.java)
+        val entry2 = mock(NetworkStats.Entry::class.java)
+
+        cache.put("iface", 2, entry1)
+        cache.put("iface", 2, entry2) // Put with the same key
+
+        assertEquals(entry2, cache.get("iface", 2))
+    }
+
+    @Test
+    fun testClear() {
+        cache.put("iface", 2, entry)
+        cache.clear()
+        assertNull(cache.get("iface", 2))
+    }
+}
diff --git a/thread/flags/Android.bp b/thread/flags/Android.bp
new file mode 100644
index 0000000..15f58a9
--- /dev/null
+++ b/thread/flags/Android.bp
@@ -0,0 +1,23 @@
+//
+// Copyright (C) 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+aconfig_declarations {
+    name: "com.android.net.thread.flags-aconfig",
+    package: "com.android.net.thread.flags",
+    container: "system",
+    srcs: ["thread_base.aconfig"],
+    visibility: ["//packages/modules/Connectivity:__subpackages__"],
+}
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkException.java b/thread/framework/java/android/net/thread/ThreadNetworkException.java
index 66f13ce..4def0fb 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkException.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkException.java
@@ -89,8 +89,9 @@
 
     /**
      * The operation failed because required preconditions were not satisfied. For example, trying
-     * to schedule a network migration when this device is not attached will receive this error. The
-     * caller should not retry the same operation before the precondition is satisfied.
+     * to schedule a network migration when this device is not attached, or trying to enable
+     * Thread while it is disallowed by user restriction, will receive this error. The caller
+     * should not retry the same operation before the precondition is satisfied.
      */
     public static final int ERROR_FAILED_PRECONDITION = 6;
 
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkManager.java b/thread/framework/java/android/net/thread/ThreadNetworkManager.java
index 28012a7..150b759 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkManager.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkManager.java
@@ -79,6 +79,17 @@
     public static final String PERMISSION_THREAD_NETWORK_PRIVILEGED =
             "android.permission.THREAD_NETWORK_PRIVILEGED";
 
+    /**
+     * This user restriction specifies if Thread network is disallowed on the device. If Thread
+     * network is disallowed it cannot be turned on via Settings.
+     *
+     * <p>This is a mirror of {@link UserManager#DISALLOW_THREAD_NETWORK}, which is not available
+     * on Android U devices.
+     *
+     * @hide
+     */
+    public static final String DISALLOW_THREAD_NETWORK = "no_thread_network";
+
     @NonNull private final Context mContext;
     @NonNull private final List<ThreadNetworkController> mUnmodifiableControllerServices;
 
diff --git a/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java b/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
new file mode 100644
index 0000000..e3b4e1a
--- /dev/null
+++ b/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
+
+import android.net.thread.ActiveOperationalDataset;
+import android.net.thread.IActiveOperationalDatasetReceiver;
+import android.os.RemoteException;
+
+import com.android.internal.annotations.GuardedBy;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * A {@link IActiveOperationalDatasetReceiver} wrapper which makes it easier to invoke the
+ * callbacks.
+ *
+ * <p>Each wrapper forwards at most one result to its receiver. Wrappers that are still pending
+ * when {@link #onOtDaemonDied} is called are failed with {@code ERROR_UNAVAILABLE}, and any
+ * result reported for them afterwards is dropped.
+ */
+final class ActiveOperationalDatasetReceiverWrapper {
+    private final IActiveOperationalDatasetReceiver mReceiver;
+
+    private static final Object sPendingReceiversLock = new Object();
+
+    /** Wrappers that have not yet forwarded a result to their receiver. */
+    @GuardedBy("sPendingReceiversLock")
+    private static final Set<ActiveOperationalDatasetReceiverWrapper> sPendingReceivers =
+            new HashSet<>();
+
+    public ActiveOperationalDatasetReceiverWrapper(IActiveOperationalDatasetReceiver receiver) {
+        this.mReceiver = receiver;
+
+        synchronized (sPendingReceiversLock) {
+            sPendingReceivers.add(this);
+        }
+    }
+
+    /** Fails every pending receiver with {@code ERROR_UNAVAILABLE}. */
+    public static void onOtDaemonDied() {
+        final Set<ActiveOperationalDatasetReceiverWrapper> pending;
+        synchronized (sPendingReceiversLock) {
+            pending = new HashSet<>(sPendingReceivers);
+            sPendingReceivers.clear();
+        }
+
+        // Invoke the clients outside the lock so that no client callback runs while
+        // sPendingReceiversLock is held.
+        for (ActiveOperationalDatasetReceiverWrapper receiver : pending) {
+            receiver.deliverError(ERROR_UNAVAILABLE, "Thread daemon died");
+        }
+    }
+
+    /** Forwards {@code dataset} to the receiver unless a result was already delivered. */
+    public void onSuccess(ActiveOperationalDataset dataset) {
+        if (!markCompleted()) {
+            return; // Already completed, e.g. failed by onOtDaemonDied()
+        }
+
+        try {
+            mReceiver.onSuccess(dataset);
+        } catch (RemoteException e) {
+            // The client is dead, do nothing
+        }
+    }
+
+    /** Forwards the error to the receiver unless a result was already delivered. */
+    public void onError(int errorCode, String errorMessage) {
+        if (!markCompleted()) {
+            return; // Already completed, e.g. failed by onOtDaemonDied()
+        }
+
+        deliverError(errorCode, errorMessage);
+    }
+
+    /**
+     * Removes this wrapper from the pending set.
+     *
+     * @return {@code true} if it was still pending, {@code false} if it was already completed
+     */
+    private boolean markCompleted() {
+        synchronized (sPendingReceiversLock) {
+            return sPendingReceivers.remove(this);
+        }
+    }
+
+    /** Invokes the receiver's onError callback, ignoring a dead client. */
+    private void deliverError(int errorCode, String errorMessage) {
+        try {
+            mReceiver.onError(errorCode, errorMessage);
+        } catch (RemoteException e) {
+            // The client is dead, do nothing
+        }
+    }
+}
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 21e3927..56dd056 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -41,11 +41,12 @@
 import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 import static android.net.thread.ThreadNetworkException.ERROR_TIMEOUT;
 import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_CHANNEL;
+import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_ABORT;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_BUSY;
-import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_DETACHED;
+import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_FAILED_PRECONDITION;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_INVALID_STATE;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_NO_BUFS;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_PARSE;
@@ -64,7 +65,10 @@
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
 import android.annotation.TargetApi;
+import android.content.BroadcastReceiver;
 import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
 import android.net.ConnectivityManager;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
@@ -98,12 +102,14 @@
 import android.os.Looper;
 import android.os.RemoteException;
 import android.os.SystemClock;
+import android.os.UserManager;
 import android.util.Log;
 import android.util.SparseArray;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.ServiceManagerWrapper;
 import com.android.server.thread.openthread.BorderRouterConfigurationParcel;
+import com.android.server.thread.openthread.IChannelMasksReceiver;
 import com.android.server.thread.openthread.IOtDaemon;
 import com.android.server.thread.openthread.IOtDaemonCallback;
 import com.android.server.thread.openthread.IOtStatusReceiver;
@@ -152,9 +158,6 @@
     private final NsdPublisher mNsdPublisher;
     private final OtDaemonCallbackProxy mOtDaemonCallbackProxy = new OtDaemonCallbackProxy();
 
-    // TODO(b/308310823): read supported channel from Thread dameon
-    private final int mSupportedChannelMask = 0x07FFF800; // from channel 11 to 26
-
     @Nullable private IOtDaemon mOtDaemon;
     @Nullable private NetworkAgent mNetworkAgent;
     @Nullable private NetworkAgent mTestNetworkAgent;
@@ -167,6 +170,8 @@
     private TestNetworkSpecifier mUpstreamTestNetworkSpecifier;
     private final HashMap<Network, String> mNetworkToInterface;
     private final ThreadPersistentSettings mPersistentSettings;
+    private final UserManager mUserManager;
+    private boolean mUserRestricted;
 
     private BorderRouterConfigurationParcel mBorderRouterConfig;
 
@@ -180,7 +185,8 @@
             TunInterfaceController tunIfController,
             InfraInterfaceController infraIfController,
             ThreadPersistentSettings persistentSettings,
-            NsdPublisher nsdPublisher) {
+            NsdPublisher nsdPublisher,
+            UserManager userManager) {
         mContext = context;
         mHandler = handler;
         mNetworkProvider = networkProvider;
@@ -193,6 +199,7 @@
         mBorderRouterConfig = new BorderRouterConfigurationParcel();
         mPersistentSettings = persistentSettings;
         mNsdPublisher = nsdPublisher;
+        mUserManager = userManager;
     }
 
     public static ThreadNetworkControllerService newInstance(
@@ -212,7 +219,8 @@
                 new TunInterfaceController(TUN_IF_NAME),
                 new InfraInterfaceController(),
                 persistentSettings,
-                NsdPublisher.newInstance(context, handler));
+                NsdPublisher.newInstance(context, handler),
+                context.getSystemService(UserManager.class));
     }
 
     private static Inet6Address bytesToInet6Address(byte[] addressBytes) {
@@ -288,10 +296,7 @@
         if (otDaemon == null) {
             throw new RemoteException("Internal error: failed to start OT daemon");
         }
-        otDaemon.initialize(
-                mTunIfController.getTunFd(),
-                mPersistentSettings.get(ThreadPersistentSettings.THREAD_ENABLED),
-                mNsdPublisher);
+        otDaemon.initialize(mTunIfController.getTunFd(), isEnabled(), mNsdPublisher);
         otDaemon.registerStateCallback(mOtDaemonCallbackProxy, -1);
         otDaemon.asBinder().linkToDeath(() -> mHandler.post(this::onOtDaemonDied), 0);
         mOtDaemon = otDaemon;
@@ -323,23 +328,39 @@
                     mConnectivityManager.registerNetworkProvider(mNetworkProvider);
                     requestUpstreamNetwork();
                     requestThreadNetwork();
-
+                    mUserRestricted = isThreadUserRestricted();
+                    registerUserRestrictionsReceiver();
                     initializeOtDaemon();
                 });
     }
 
-    public void setEnabled(@NonNull boolean isEnabled, @NonNull IOperationReceiver receiver) {
+    public void setEnabled(boolean isEnabled, @NonNull IOperationReceiver receiver) {
         enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
-        mHandler.post(() -> setEnabledInternal(isEnabled, new OperationReceiverWrapper(receiver)));
+        mHandler.post(
+                () ->
+                        setEnabledInternal(
+                                isEnabled,
+                                true /* persist */,
+                                new OperationReceiverWrapper(receiver)));
     }
 
     private void setEnabledInternal(
-            @NonNull boolean isEnabled, @Nullable OperationReceiverWrapper receiver) {
-        // The persistent setting keeps the desired enabled state, thus it's set regardless
-        // the otDaemon set enabled state operation succeeded or not, so that it can recover
-        // to the desired value after reboot.
-        mPersistentSettings.put(ThreadPersistentSettings.THREAD_ENABLED.key, isEnabled);
+            boolean isEnabled, boolean persist, @NonNull OperationReceiverWrapper receiver) {
+        if (isEnabled && isThreadUserRestricted()) {
+            receiver.onError(
+                    ERROR_FAILED_PRECONDITION,
+                    "Cannot enable Thread: forbidden by user restriction");
+            return;
+        }
+
+        if (persist) {
+            // The persistent setting keeps the desired enabled state, thus it's set regardless
+            // the otDaemon set enabled state operation succeeded or not, so that it can recover
+            // to the desired value after reboot.
+            mPersistentSettings.put(ThreadPersistentSettings.THREAD_ENABLED.key, isEnabled);
+        }
+
         try {
             getOtDaemon().setThreadEnabled(isEnabled, newOtStatusReceiver(receiver));
         } catch (RemoteException e) {
@@ -348,6 +369,67 @@
         }
     }
 
+    private void registerUserRestrictionsReceiver() {
+        mContext.registerReceiver(
+                new BroadcastReceiver() {
+                    @Override
+                    public void onReceive(Context context, Intent intent) {
+                        onUserRestrictionsChanged(isThreadUserRestricted());
+                    }
+                },
+                new IntentFilter(UserManager.ACTION_USER_RESTRICTIONS_CHANGED),
+                null /* broadcastPermission */,
+                mHandler);
+    }
+
+    private void onUserRestrictionsChanged(boolean newUserRestrictedState) {
+        checkOnHandlerThread();
+        // ACTION_USER_RESTRICTIONS_CHANGED is not specific to DISALLOW_THREAD_NETWORK, so ignore
+        // broadcasts that leave the Thread restriction unchanged.
+        if (mUserRestricted == newUserRestrictedState) {
+            return;
+        }
+        Log.i(TAG, "Thread user restriction changed: " + mUserRestricted + " -> "
+                + newUserRestrictedState);
+        mUserRestricted = newUserRestrictedState;
+
+        // With mUserRestricted updated, isEnabled() is false while restricted and otherwise the
+        // persisted user setting; re-apply that state to ot-daemon.
+        final boolean isEnabled = isEnabled();
+        final IOperationReceiver receiver =
+                new IOperationReceiver.Stub() {
+                    @Override
+                    public void onSuccess() {
+                        Log.d(
+                                TAG,
+                                (isEnabled ? "Enabled" : "Disabled")
+                                        + " Thread due to user restriction change");
+                    }
+
+                    @Override
+                    public void onError(int errorCode, String errorMessage) {
+                        Log.e(
+                                TAG,
+                                "Failed to " + (isEnabled ? "enable" : "disable")
+                                        + " Thread for user restriction change: errorCode="
+                                        + errorCode + ", errorMessage=" + errorMessage);
+                    }
+                };
+        // Do not save the user restriction state to persistent settings so that the user
+        // configuration won't be overwritten
+        setEnabledInternal(isEnabled, false /* persist */, new OperationReceiverWrapper(receiver));
+    }
+
+    /** Returns {@code true} if Thread is set enabled and not restricted for the user. */
+    private boolean isEnabled() {
+        return !mUserRestricted && mPersistentSettings.get(ThreadPersistentSettings.THREAD_ENABLED);
+    }
+
+    /** Returns {@code true} if UserManager reports {@code DISALLOW_THREAD_NETWORK} as set. */
+    private boolean isThreadUserRestricted() {
+        return mUserManager.hasUserRestriction(DISALLOW_THREAD_NETWORK);
+    }
+
     private void requestUpstreamNetwork() {
         if (mUpstreamNetworkCallback != null) {
             throw new AssertionError("The upstream network request is already there.");
@@ -509,26 +591,51 @@
     @Override
     public void createRandomizedDataset(
             String networkName, IActiveOperationalDatasetReceiver receiver) {
-        mHandler.post(
-                () -> {
-                    ActiveOperationalDataset dataset =
-                            createRandomizedDatasetInternal(
-                                    networkName,
-                                    mSupportedChannelMask,
-                                    Instant.now(),
-                                    new Random(),
-                                    new SecureRandom());
-                    try {
-                        receiver.onSuccess(dataset);
-                    } catch (RemoteException e) {
-                        // The client is dead, do nothing
-                    }
-                });
+        ActiveOperationalDatasetReceiverWrapper receiverWrapper =
+                new ActiveOperationalDatasetReceiverWrapper(receiver);
+        mHandler.post(() -> createRandomizedDatasetInternal(networkName, receiverWrapper));
     }
 
-    private static ActiveOperationalDataset createRandomizedDatasetInternal(
+    private void createRandomizedDatasetInternal(
+            String networkName, @NonNull ActiveOperationalDatasetReceiverWrapper receiver) {
+        checkOnHandlerThread();
+        // Ask ot-daemon for its channel masks; the dataset is generated in the masks callback.
+        try {
+            getOtDaemon().getChannelMasks(newChannelMasksReceiver(networkName, receiver));
+        } catch (RemoteException e) {
+            Log.e(TAG, "otDaemon.getChannelMasks failed", e);
+            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        }
+    }
+
+    private IChannelMasksReceiver newChannelMasksReceiver(
+            String networkName, ActiveOperationalDatasetReceiverWrapper receiver) {
+        return new IChannelMasksReceiver.Stub() {
+            // Called back with the supported and preferred channel masks from getChannelMasks().
+            @Override
+            public void onSuccess(int supportedChannelMask, int preferredChannelMask) {
+                receiver.onSuccess(
+                        createRandomizedDataset(
+                                networkName,
+                                supportedChannelMask,
+                                preferredChannelMask,
+                                Instant.now(),
+                                new Random(),
+                                new SecureRandom()));
+            }
+
+            // Translates the ot-daemon error code before failing the client's receiver.
+            @Override
+            public void onError(int otErrorCode, String errorMessage) {
+                receiver.onError(otErrorToAndroidError(otErrorCode), errorMessage);
+            }
+        };
+    }
+
+    private static ActiveOperationalDataset createRandomizedDataset(
             String networkName,
             int supportedChannelMask,
+            int preferredChannelMask,
             Instant now,
             Random random,
             SecureRandom secureRandom) {
@@ -538,6 +645,7 @@
 
         final SparseArray<byte[]> channelMask = new SparseArray<>(1);
         channelMask.put(CHANNEL_PAGE_24_GHZ, channelMaskToByteArray(supportedChannelMask));
+        final int channel = selectChannel(supportedChannelMask, preferredChannelMask, random);
 
         final byte[] securityFlags = new byte[] {(byte) 0xff, (byte) 0xf8};
 
@@ -548,7 +656,7 @@
                 .setExtendedPanId(newRandomBytes(random, LENGTH_EXTENDED_PAN_ID))
                 .setPanId(panId)
                 .setNetworkName(networkName)
-                .setChannel(CHANNEL_PAGE_24_GHZ, selectRandomChannel(supportedChannelMask, random))
+                .setChannel(CHANNEL_PAGE_24_GHZ, channel)
                 .setChannelMask(channelMask)
                 .setPskc(newRandomBytes(secureRandom, LENGTH_PSKC))
                 .setNetworkKey(newRandomBytes(secureRandom, LENGTH_NETWORK_KEY))
@@ -557,6 +665,18 @@
                 .build();
     }
 
+    private static int selectChannel(
+            int supportedChannelMask, int preferredChannelMask, Random random) {
+        // If the preferred channel mask is not empty, select a random channel from it, otherwise
+        // choose one from the supported channel mask.
+        preferredChannelMask = preferredChannelMask & supportedChannelMask;
+        if (preferredChannelMask == 0) {
+            preferredChannelMask = supportedChannelMask;
+        }
+
+        return selectRandomChannel(preferredChannelMask, random);
+    }
+
     private static byte[] newRandomBytes(Random random, int length) {
         byte[] result = new byte[length];
         random.nextBytes(result);
@@ -656,9 +776,6 @@
                 return ERROR_ABORTED;
             case OT_ERROR_BUSY:
                 return ERROR_BUSY;
-            case OT_ERROR_DETACHED:
-            case OT_ERROR_INVALID_STATE:
-                return ERROR_FAILED_PRECONDITION;
             case OT_ERROR_NO_BUFS:
                 return ERROR_RESOURCE_EXHAUSTED;
             case OT_ERROR_PARSE:
@@ -672,6 +789,9 @@
                 return ERROR_UNSUPPORTED_CHANNEL;
             case OT_ERROR_THREAD_DISABLED:
                 return ERROR_THREAD_DISABLED;
+            case OT_ERROR_FAILED_PRECONDITION:
+                return ERROR_FAILED_PRECONDITION;
+            case OT_ERROR_INVALID_STATE:
             default:
                 return ERROR_INTERNAL_ERROR;
         }
diff --git a/thread/tests/unit/AndroidTest.xml b/thread/tests/unit/AndroidTest.xml
index 26813c1..d16e423 100644
--- a/thread/tests/unit/AndroidTest.xml
+++ b/thread/tests/unit/AndroidTest.xml
@@ -19,6 +19,18 @@
     <option name="test-tag" value="ThreadNetworkUnitTests" />
     <option name="test-suite-tag" value="apct" />
 
+    <!--
+        Only run tests if the device under test is SDK version 34 (Android 14) or above.
+    -->
+    <object type="module_controller"
+            class="com.android.tradefed.testtype.suite.module.Sdk34ModuleController" />
+
+    <!-- Run tests in MTS only if the Tethering Mainline module is installed. -->
+    <object type="module_controller"
+            class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
+        <option name="mainline-module-package-name" value="com.google.android.tethering" />
+    </object>
+
     <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
         <option name="test-file-name" value="ThreadNetworkUnitTests.apk" />
         <option name="check-min-sdk" value="true" />
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index f626edf..4948c22 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -16,32 +16,50 @@
 
 package com.android.server.thread;
 
+import static android.Manifest.permission.ACCESS_NETWORK_STATE;
+import static android.net.thread.ActiveOperationalDataset.CHANNEL_PAGE_24_GHZ;
+import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
+import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
+import static android.net.thread.ThreadNetworkException.ERROR_FAILED_PRECONDITION;
 import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
+import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 
-import static com.android.testutils.TestPermissionUtil.runAsShell;
+import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_INVALID_STATE;
 
 import static com.google.common.io.BaseEncoding.base16;
 import static com.google.common.truth.Truth.assertThat;
 
+import static org.junit.Assert.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import android.content.BroadcastReceiver;
 import android.content.Context;
+import android.content.Intent;
 import android.net.ConnectivityManager;
 import android.net.NetworkAgent;
 import android.net.NetworkProvider;
 import android.net.thread.ActiveOperationalDataset;
+import android.net.thread.IActiveOperationalDatasetReceiver;
 import android.net.thread.IOperationReceiver;
+import android.net.thread.ThreadNetworkException;
 import android.os.Handler;
+import android.os.IBinder;
 import android.os.ParcelFileDescriptor;
 import android.os.RemoteException;
+import android.os.UserManager;
 import android.os.test.TestLooper;
 
 import androidx.test.core.app.ApplicationProvider;
@@ -53,9 +71,15 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
+
 /** Unit tests for {@link ThreadNetworkControllerService}. */
 @SmallTest
 @RunWith(AndroidJUnit4.class)
@@ -80,6 +104,12 @@
                                     + "B9D351B40C0402A0FFF8");
     private static final ActiveOperationalDataset DEFAULT_ACTIVE_DATASET =
             ActiveOperationalDataset.fromThreadTlvs(DEFAULT_ACTIVE_DATASET_TLVS);
+    private static final String DEFAULT_NETWORK_NAME = "thread-wpan0";
+    private static final int OT_ERROR_NONE = 0;
+    private static final int DEFAULT_SUPPORTED_CHANNEL_MASK = 0x07FFF800; // from channel 11 to 26
+    private static final int DEFAULT_PREFERRED_CHANNEL_MASK = 0x00000800; // channel 11
+    private static final int DEFAULT_SELECTED_CHANNEL = 11;
+    private static final byte[] DEFAULT_SUPPORTED_CHANNEL_MASK_ARRAY = base16().decode("001FFFE0");
 
     @Mock private ConnectivityManager mMockConnectivityManager;
     @Mock private NetworkAgent mMockNetworkAgent;
@@ -88,30 +118,38 @@
     @Mock private InfraInterfaceController mMockInfraIfController;
     @Mock private ThreadPersistentSettings mMockPersistentSettings;
     @Mock private NsdPublisher mMockNsdPublisher;
+    @Mock private UserManager mMockUserManager;
+    @Mock private IBinder mIBinder;
     private Context mContext;
     private TestLooper mTestLooper;
     private FakeOtDaemon mFakeOtDaemon;
     private ThreadNetworkControllerService mService;
+    @Captor private ArgumentCaptor<ActiveOperationalDataset> mActiveDatasetCaptor;
 
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
 
-        mContext = ApplicationProvider.getApplicationContext();
+        mContext = spy(ApplicationProvider.getApplicationContext());
+        doNothing()
+                .when(mContext)
+                .enforceCallingOrSelfPermission(
+                        eq(PERMISSION_THREAD_NETWORK_PRIVILEGED), anyString());
+
         mTestLooper = new TestLooper();
         final Handler handler = new Handler(mTestLooper.getLooper());
         NetworkProvider networkProvider =
                 new NetworkProvider(mContext, mTestLooper.getLooper(), "ThreadNetworkProvider");
 
         mFakeOtDaemon = new FakeOtDaemon(handler);
-
         when(mMockTunIfController.getTunFd()).thenReturn(mMockTunFd);
 
         when(mMockPersistentSettings.get(any())).thenReturn(true);
+        when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(false);
 
         mService =
                 new ThreadNetworkControllerService(
-                        ApplicationProvider.getApplicationContext(),
+                        mContext,
                         handler,
                         networkProvider,
                         () -> mFakeOtDaemon,
@@ -119,7 +157,8 @@
                         mMockTunIfController,
                         mMockInfraIfController,
                         mMockPersistentSettings,
-                        mMockNsdPublisher);
+                        mMockNsdPublisher,
+                        mMockUserManager);
         mService.setTestNetworkAgent(mMockNetworkAgent);
     }
 
@@ -141,9 +180,7 @@
         final IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
         mFakeOtDaemon.setJoinException(new RemoteException("ot-daemon join() throws"));
 
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                () -> mService.join(DEFAULT_ACTIVE_DATASET, mockReceiver));
+        mService.join(DEFAULT_ACTIVE_DATASET, mockReceiver);
         mTestLooper.dispatchAll();
 
         verify(mockReceiver, never()).onSuccess();
@@ -155,9 +192,7 @@
         mService.initialize();
         final IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
 
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                () -> mService.join(DEFAULT_ACTIVE_DATASET, mockReceiver));
+        mService.join(DEFAULT_ACTIVE_DATASET, mockReceiver);
         // Here needs to call Testlooper#dispatchAll twices because TestLooper#moveTimeForward
         // operates on only currently enqueued messages but the delayed message is posted from
         // another Handler task.
@@ -168,4 +203,134 @@
         verify(mockReceiver, times(1)).onSuccess();
         verify(mMockNetworkAgent, times(1)).register();
     }
+
+    @Test
+    public void userRestriction_initWithUserRestricted_threadIsDisabled() {
+        // DISALLOW_THREAD_NETWORK is already in force when the service initializes.
+        when(mMockUserManager.hasUserRestriction(DISALLOW_THREAD_NETWORK)).thenReturn(true);
+        mService.initialize();
+        mTestLooper.dispatchAll(); // run any work queued on the test looper
+
+        assertThat(mFakeOtDaemon.getEnabledState()).isEqualTo(STATE_DISABLED);
+    }
+
+    @Test
+    public void userRestriction_initWithUserNotRestricted_threadIsEnabled() {
+        // Stated explicitly even though setUp() stubs the same default: no restriction.
+        when(mMockUserManager.hasUserRestriction(DISALLOW_THREAD_NETWORK)).thenReturn(false);
+        mService.initialize();
+        mTestLooper.dispatchAll(); // run any work queued on the test looper
+
+        assertThat(mFakeOtDaemon.getEnabledState()).isEqualTo(STATE_ENABLED);
+    }
+
+    @Test
+    public void userRestriction_userBecomesRestricted_stateIsDisabledButNotPersisted() {
+        AtomicReference<BroadcastReceiver> receiverRef = new AtomicReference<>();
+        when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(false);
+        doAnswer(
+                        invocation -> {
+                            receiverRef.set((BroadcastReceiver) invocation.getArguments()[0]);
+                            return null;
+                        })
+                .when(mContext)
+                .registerReceiver(any(BroadcastReceiver.class), any(), any(), any());
+        mService.initialize();
+        mTestLooper.dispatchAll();
+        assertThat(receiverRef.get()).isNotNull(); // fail clearly here, not via NPE below
+        // Turn the restriction on, then deliver a broadcast to the captured receiver.
+        when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(true);
+        receiverRef.get().onReceive(mContext, new Intent());
+        mTestLooper.dispatchAll();
+        assertThat(mFakeOtDaemon.getEnabledState()).isEqualTo(STATE_DISABLED);
+        verify(mMockPersistentSettings, never()) // restriction-driven change is not persisted
+                .put(eq(ThreadPersistentSettings.THREAD_ENABLED.key), eq(false));
+    }
+
+    @Test
+    public void userRestriction_userBecomesNotRestricted_stateIsEnabledButNotPersisted() {
+        AtomicReference<BroadcastReceiver> receiverRef = new AtomicReference<>();
+        when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(true);
+        doAnswer(
+                        invocation -> {
+                            receiverRef.set((BroadcastReceiver) invocation.getArguments()[0]);
+                            return null;
+                        })
+                .when(mContext)
+                .registerReceiver(any(BroadcastReceiver.class), any(), any(), any());
+        mService.initialize();
+        mTestLooper.dispatchAll();
+        assertThat(receiverRef.get()).isNotNull(); // fail clearly here, not via NPE below
+        // Lift the restriction, then deliver a broadcast to the captured receiver.
+        when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(false);
+        receiverRef.get().onReceive(mContext, new Intent());
+        mTestLooper.dispatchAll();
+        assertThat(mFakeOtDaemon.getEnabledState()).isEqualTo(STATE_ENABLED);
+        verify(mMockPersistentSettings, never()) // restriction-driven change is not persisted
+                .put(eq(ThreadPersistentSettings.THREAD_ENABLED.key), eq(true));
+    }
+
+    @Test
+    public void userRestriction_setEnabledWhenUserRestricted_failedPreconditionError() {
+        when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(true);
+        mService.initialize();
+        // setEnabled(true) must be rejected while DISALLOW_THREAD_NETWORK is in force.
+        CompletableFuture<Void> setEnabledFuture = new CompletableFuture<>();
+        mService.setEnabled(true, newOperationReceiver(setEnabledFuture));
+        mTestLooper.dispatchAll();
+        assertThat(setEnabledFuture.isDone()).isTrue(); // fail instead of blocking in get() below
+        var thrown = assertThrows(ExecutionException.class, () -> setEnabledFuture.get());
+        ThreadNetworkException failure = (ThreadNetworkException) thrown.getCause();
+        assertThat(failure.getErrorCode()).isEqualTo(ERROR_FAILED_PRECONDITION);
+    }
+
+    private static IOperationReceiver newOperationReceiver(CompletableFuture<Void> future) {
+        return new IOperationReceiver.Stub() {
+            @Override
+            public void onSuccess() {
+                future.complete(null);
+            }
+
+            @Override
+            public void onError(int errorCode, String errorMessage) {
+                future.completeExceptionally(new ThreadNetworkException(errorCode, errorMessage));
+            }
+        };
+    }
+
+    @Test
+    public void createRandomizedDataset_succeed_activeDatasetCreated() throws Exception {
+        final IActiveOperationalDatasetReceiver mockReceiver =
+                mock(IActiveOperationalDatasetReceiver.class);
+        mFakeOtDaemon.setChannelMasks(
+                DEFAULT_SUPPORTED_CHANNEL_MASK, DEFAULT_PREFERRED_CHANNEL_MASK);
+        mFakeOtDaemon.setChannelMasksReceiverOtError(OT_ERROR_NONE);
+
+        mService.createRandomizedDataset(DEFAULT_NETWORK_NAME, mockReceiver);
+        mTestLooper.dispatchAll();
+
+        verify(mockReceiver, never()).onError(anyInt(), anyString());
+        verify(mockReceiver, times(1)).onSuccess(mActiveDatasetCaptor.capture());
+        ActiveOperationalDataset activeDataset = mActiveDatasetCaptor.getValue();
+        assertThat(activeDataset.getNetworkName()).isEqualTo(DEFAULT_NETWORK_NAME);
+        assertThat(activeDataset.getChannelMask().size()).isEqualTo(1);
+        assertThat(activeDataset.getChannelMask().get(CHANNEL_PAGE_24_GHZ))
+                .isEqualTo(DEFAULT_SUPPORTED_CHANNEL_MASK_ARRAY);
+        assertThat(activeDataset.getChannel()).isEqualTo(DEFAULT_SELECTED_CHANNEL);
+    }
+
+    @Test
+    public void createRandomizedDataset_otDaemonRemoteFailure_returnsPreconditionError()
+            throws Exception {
+        final IActiveOperationalDatasetReceiver mockReceiver =
+                mock(IActiveOperationalDatasetReceiver.class);
+        mFakeOtDaemon.setChannelMasksReceiverOtError(OT_ERROR_INVALID_STATE);
+        when(mockReceiver.asBinder()).thenReturn(mIBinder);
+
+        mService.createRandomizedDataset(DEFAULT_NETWORK_NAME, mockReceiver);
+        mTestLooper.dispatchAll();
+
+        verify(mockReceiver, never()).onSuccess(any(ActiveOperationalDataset.class));
+        verify(mockReceiver, times(1)).onError(eq(ERROR_INTERNAL_ERROR), anyString());
+    }
 }