diff options
| -rw-r--r-- | services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java | 29 | ||||
| -rw-r--r-- | services/companion/java/com/android/server/companion/transport/Transport.java | 32 |
2 files changed, 46 insertions, 15 deletions
diff --git a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java index 91ba9b3749fd..74908a4613be 100644 --- a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java +++ b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java @@ -39,7 +39,9 @@ import java.io.IOException; import java.io.PrintWriter; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; @@ -59,9 +61,12 @@ public class CompanionTransportManager { @GuardedBy("mTransportsListeners") private final RemoteCallbackList<IOnTransportsChangedListener> mTransportsListeners = new RemoteCallbackList<>(); + /** Message type -> IOnMessageReceivedListener */ + @GuardedBy("mMessageListeners") @NonNull - private final SparseArray<IOnMessageReceivedListener> mMessageListeners = new SparseArray<>(); + private final SparseArray<Set<IOnMessageReceivedListener>> mMessageListeners = + new SparseArray<>(); public CompanionTransportManager(Context context, AssociationStore associationStore) { mContext = context; @@ -72,7 +77,12 @@ public class CompanionTransportManager { * Add a listener to receive callbacks when a message is received for the message type */ public void addListener(int message, @NonNull IOnMessageReceivedListener listener) { - mMessageListeners.put(message, listener); + synchronized (mMessageListeners) { + if (!mMessageListeners.contains(message)) { + mMessageListeners.put(message, new HashSet<IOnMessageReceivedListener>()); + } + mMessageListeners.get(message).add(listener); + } synchronized (mTransports) { for (int i = 0; i < mTransports.size(); i++) { mTransports.valueAt(i).addListener(message, listener); @@ -113,7 +123,12 @@ public class CompanionTransportManager { * Remove the listener to stop receiving calbacks when a message is received for the given type */ public void removeListener(int messageType, IOnMessageReceivedListener listener) { - mMessageListeners.remove(messageType); + synchronized (mMessageListeners) { + if (!mMessageListeners.contains(messageType)) { + return; + } + mMessageListeners.get(messageType).remove(listener); + } } /** @@ -315,8 +330,12 @@ public class CompanionTransportManager { } private void addMessageListenersToTransport(Transport transport) { - for (int i = 0; i < mMessageListeners.size(); i++) { - transport.addListener(mMessageListeners.keyAt(i), mMessageListeners.valueAt(i)); + synchronized (mMessageListeners) { + for (int i = 0; i < mMessageListeners.size(); i++) { + for (IOnMessageReceivedListener listener : mMessageListeners.valueAt(i)) { + transport.addListener(mMessageListeners.keyAt(i), listener); + } + } } } diff --git a/services/companion/java/com/android/server/companion/transport/Transport.java b/services/companion/java/com/android/server/companion/transport/Transport.java index 8a5774e55ce2..986bd6c91e17 100644 --- a/services/companion/java/com/android/server/companion/transport/Transport.java +++ b/services/companion/java/com/android/server/companion/transport/Transport.java @@ -40,8 +40,8 @@ import libcore.util.EmptyArray; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.util.HashMap; -import java.util.Map; +import java.util.HashSet; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; @@ -71,7 +71,8 @@ public abstract class Transport { * the future to allow multiple listeners to receive callbacks for the same message type, the * value of the map can be a list. */ - private final Map<Integer, IOnMessageReceivedListener> mListeners; + @GuardedBy("mListeners") + private final SparseArray<Set<IOnMessageReceivedListener>> mListeners = new SparseArray<>(); private OnTransportClosedListener mOnTransportClosed; @@ -98,7 +99,6 @@ public abstract class Transport { mRemoteIn = new ParcelFileDescriptor.AutoCloseInputStream(fd); mRemoteOut = new ParcelFileDescriptor.AutoCloseOutputStream(fd); mContext = context; - mListeners = new HashMap<>(); } /** @@ -107,7 +107,12 @@ public abstract class Transport { * @param listener Execute when a message with the type is received */ public void addListener(int message, IOnMessageReceivedListener listener) { - mListeners.put(message, listener); + synchronized (mListeners) { + if (!mListeners.contains(message)) { + mListeners.put(message, new HashSet<IOnMessageReceivedListener>()); + } + mListeners.get(message).add(listener); + } } public int getAssociationId() { @@ -281,12 +286,19 @@ public abstract class Transport { } private void callback(int message, byte[] data) { - if (mListeners.containsKey(message)) { + Set<IOnMessageReceivedListener> listenersToCall; + synchronized (mListeners) { + if (!mListeners.contains(message)) { + return; + } + listenersToCall = mListeners.get(message); + } + Slog.d(TAG, "Message 0x" + Integer.toHexString(message) + + " is received from associationId " + mAssociationId + + ", sending data length " + data.length + " to the listener(s)."); + for (IOnMessageReceivedListener listener: listenersToCall) { try { - mListeners.get(message).onMessageReceived(getAssociationId(), data); - Slog.d(TAG, "Message 0x" + Integer.toHexString(message) - + " is received from associationId " + mAssociationId - + ", sending data length " + data.length + " to the listener."); + listener.onMessageReceived(getAssociationId(), data); } catch (RemoteException ignored) { } } |