diff options
| author | 2023-02-24 19:50:15 +0000 | |
|---|---|---|
| committer | 2023-02-24 19:50:15 +0000 | |
| commit | 9fe0570eea50eb3d52f6f56ab3f698a140fca3ac (patch) | |
| tree | 031e2af7db68b89b2af52fe1b74a98abed1ca1d5 | |
| parent | 6c3beccf15a050aa5a47be2e60bbca796d856651 (diff) | |
| parent | e7cd9c3e167e0417a53e450e32a9a6e1d2108b14 (diff) | |
Merge "CDM Transport clean-up" into udc-dev
4 files changed, 415 insertions, 317 deletions
diff --git a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java index 494c5a6c0779..6a53adfeea9d 100644 --- a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java +++ b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java @@ -18,51 +18,31 @@ package com.android.server.companion.transport; import static android.Manifest.permission.DELIVER_COMPANION_MESSAGES; +import static com.android.server.companion.transport.Transport.MESSAGE_REQUEST_PERMISSION_RESTORE; + import android.annotation.NonNull; import android.annotation.Nullable; import android.annotation.SuppressLint; import android.app.ActivityManagerInternal; import android.content.Context; import android.content.pm.ApplicationInfo; -import android.content.pm.PackageManager; import android.content.pm.PackageManager.NameNotFoundException; import android.os.Binder; import android.os.Build; import android.os.ParcelFileDescriptor; -import android.util.Slog; import android.util.SparseArray; import com.android.internal.annotations.GuardedBy; import com.android.server.LocalServices; -import com.android.server.companion.securechannel.SecureChannel; - -import libcore.io.IoUtils; -import libcore.io.Streams; -import libcore.util.EmptyArray; import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; -import java.util.concurrent.atomic.AtomicInteger; @SuppressLint("LongLogTag") public class CompanionTransportManager { private static final String TAG = "CDM_CompanionTransportManager"; - // TODO: flip to false - private static final boolean DEBUG = true; - - private static final int HEADER_LENGTH = 12; - - private static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN - private static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES - - private static final int MESSAGE_RESPONSE_SUCCESS = 0x33838567; // !SUC - private static final int MESSAGE_RESPONSE_FAILURE = 0x33706573; // !FAI + private static final boolean DEBUG = false; private boolean mSecureTransportEnabled = true; @@ -127,9 +107,9 @@ public class CompanionTransportManager { final Transport transport; if (isSecureTransportEnabled(associationId)) { - transport = new SecureTransport(associationId, fd); + transport = new SecureTransport(associationId, fd, mContext, mListener); } else { - transport = new RawTransport(associationId, fd); + transport = new RawTransport(associationId, fd, mContext, mListener); } transport.start(); @@ -172,296 +152,4 @@ public class CompanionTransportManager { // TODO: version comparison logic return enabled; } - - // TODO: Make Transport inner classes into standalone classes. - private abstract class Transport { - protected final int mAssociationId; - protected final InputStream mRemoteIn; - protected final OutputStream mRemoteOut; - - @GuardedBy("mPendingRequests") - protected final SparseArray<CompletableFuture<byte[]>> mPendingRequests = - new SparseArray<>(); - protected final AtomicInteger mNextSequence = new AtomicInteger(); - - Transport(int associationId, ParcelFileDescriptor fd) { - this(associationId, - new ParcelFileDescriptor.AutoCloseInputStream(fd), - new ParcelFileDescriptor.AutoCloseOutputStream(fd)); - } - - Transport(int associationId, InputStream in, OutputStream out) { - this.mAssociationId = associationId; - this.mRemoteIn = in; - this.mRemoteOut = out; - } - - public abstract void start(); - public abstract void stop(); - - protected abstract void sendMessage(int message, int sequence, @NonNull byte[] data) - throws IOException; - - public Future<byte[]> requestForResponse(int message, byte[] data) { - if (DEBUG) Slog.d(TAG, "Requesting for response"); - final int sequence = mNextSequence.incrementAndGet(); - final CompletableFuture<byte[]> pending = new CompletableFuture<>(); - synchronized (mPendingRequests) { - mPendingRequests.put(sequence, pending); - } - - try { - sendMessage(message, sequence, data); - } catch (IOException e) { - synchronized (mPendingRequests) { - mPendingRequests.remove(sequence); - } - pending.completeExceptionally(e); - } - - return pending; - } - - protected final void handleMessage(int message, int sequence, @NonNull byte[] data) - throws IOException { - if (DEBUG) { - Slog.d(TAG, "Received message 0x" + Integer.toHexString(message) - + " sequence " + sequence + " length " + data.length - + " from association " + mAssociationId); - } - - if (isRequest(message)) { - try { - processRequest(message, sequence, data); - } catch (IOException e) { - Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e); - } - } else if (isResponse(message)) { - processResponse(message, sequence, data); - } else { - Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message)); - } - } - - private void processRequest(int message, int sequence, byte[] data) - throws IOException { - switch (message) { - case MESSAGE_REQUEST_PING: { - sendMessage(MESSAGE_RESPONSE_SUCCESS, sequence, data); - break; - } - case MESSAGE_REQUEST_PERMISSION_RESTORE: { - if (!mContext.getPackageManager().hasSystemFeature(PackageManager.FEATURE_WATCH) - && !Build.isDebuggable()) { - Slog.w(TAG, "Restoring permissions only supported on watches"); - sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE); - break; - } - try { - mListener.onRequestPermissionRestore(data); - sendMessage(MESSAGE_RESPONSE_SUCCESS, sequence, EmptyArray.BYTE); - } catch (Exception e) { - Slog.w(TAG, "Failed to restore permissions"); - sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE); - } - break; - } - default: { - Slog.w(TAG, "Unknown request 0x" + Integer.toHexString(message)); - sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE); - break; - } - } - } - - private void processResponse(int message, int sequence, byte[] data) { - final CompletableFuture<byte[]> future; - synchronized (mPendingRequests) { - future = mPendingRequests.removeReturnOld(sequence); - } - if (future == null) { - Slog.w(TAG, "Ignoring unknown sequence " + sequence); - return; - } - - switch (message) { - case MESSAGE_RESPONSE_SUCCESS: { - future.complete(data); - break; - } - case MESSAGE_RESPONSE_FAILURE: { - future.completeExceptionally(new RuntimeException("Remote failure")); - break; - } - default: { - Slog.w(TAG, "Ignoring unknown response 0x" + Integer.toHexString(message)); - } - } - } - } - - private class RawTransport extends Transport { - private volatile boolean mStopped; - - RawTransport(int associationId, ParcelFileDescriptor fd) { - super(associationId, fd); - } - - @Override - public void start() { - new Thread(() -> { - try { - while (!mStopped) { - receiveMessage(); - } - } catch (IOException e) { - if (!mStopped) { - Slog.w(TAG, "Trouble during transport", e); - stop(); - } - } - }).start(); - } - - @Override - public void stop() { - mStopped = true; - - IoUtils.closeQuietly(mRemoteIn); - IoUtils.closeQuietly(mRemoteOut); - } - - @Override - protected void sendMessage(int message, int sequence, @NonNull byte[] data) - throws IOException { - if (DEBUG) { - Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message) - + " sequence " + sequence + " length " + data.length - + " to association " + mAssociationId); - } - - synchronized (mRemoteOut) { - final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH) - .putInt(message) - .putInt(sequence) - .putInt(data.length); - mRemoteOut.write(header.array()); - mRemoteOut.write(data); - mRemoteOut.flush(); - } - } - - private void receiveMessage() throws IOException { - final byte[] headerBytes = new byte[HEADER_LENGTH]; - Streams.readFully(mRemoteIn, headerBytes); - final ByteBuffer header = ByteBuffer.wrap(headerBytes); - final int message = header.getInt(); - final int sequence = header.getInt(); - final int length = header.getInt(); - final byte[] data = new byte[length]; - Streams.readFully(mRemoteIn, data); - - handleMessage(message, sequence, data); - } - } - - private class SecureTransport extends Transport implements SecureChannel.Callback { - private final SecureChannel mSecureChannel; - - private volatile boolean mShouldProcessRequests = false; - - private final BlockingQueue<byte[]> mRequestQueue = new ArrayBlockingQueue<>(100); - - SecureTransport(int associationId, ParcelFileDescriptor fd) { - super(associationId, fd); - mSecureChannel = new SecureChannel(mRemoteIn, mRemoteOut, this, mContext); - } - - @Override - public void start() { - mSecureChannel.start(); - } - - @Override - public void stop() { - mSecureChannel.stop(); - mShouldProcessRequests = false; - } - - @Override - public Future<byte[]> requestForResponse(int message, byte[] data) { - // Check if channel is secured and start securing - if (!mShouldProcessRequests) { - Slog.d(TAG, "Establishing secure connection."); - try { - mSecureChannel.establishSecureConnection(); - } catch (Exception e) { - Slog.w(TAG, "Failed to initiate secure channel handshake.", e); - onError(e); - } - } - - return super.requestForResponse(message, data); - } - - @Override - protected void sendMessage(int message, int sequence, @NonNull byte[] data) - throws IOException { - if (DEBUG) { - Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message) - + " sequence " + sequence + " length " + data.length - + " to association " + mAssociationId); - } - - // Queue up a message to send - mRequestQueue.add(ByteBuffer.allocate(HEADER_LENGTH + data.length) - .putInt(message) - .putInt(sequence) - .putInt(data.length) - .put(data) - .array()); - } - - @Override - public void onSecureConnection() { - mShouldProcessRequests = true; - Slog.d(TAG, "Secure connection established."); - - // TODO: find a better way to handle incoming requests than a dedicated thread. - new Thread(() -> { - try { - while (mShouldProcessRequests) { - byte[] request = mRequestQueue.poll(); - if (request != null) { - mSecureChannel.sendSecureMessage(request); - } - } - } catch (IOException e) { - onError(e); - } - }).start(); - } - - @Override - public void onSecureMessageReceived(byte[] data) { - final ByteBuffer payload = ByteBuffer.wrap(data); - final int message = payload.getInt(); - final int sequence = payload.getInt(); - final int length = payload.getInt(); - final byte[] content = new byte[length]; - payload.get(content); - - try { - handleMessage(message, sequence, content); - } catch (IOException error) { - onError(error); - } - } - - @Override - public void onError(Throwable error) { - mShouldProcessRequests = false; - Slog.e(TAG, error.getMessage(), error); - } - } } diff --git a/services/companion/java/com/android/server/companion/transport/RawTransport.java b/services/companion/java/com/android/server/companion/transport/RawTransport.java new file mode 100644 index 000000000000..7c0c7cf7ac68 --- /dev/null +++ b/services/companion/java/com/android/server/companion/transport/RawTransport.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.companion.transport; + +import android.annotation.NonNull; +import android.content.Context; +import android.os.ParcelFileDescriptor; +import android.util.Slog; + +import com.android.server.companion.transport.CompanionTransportManager.Listener; + +import libcore.io.IoUtils; +import libcore.io.Streams; + +import java.io.IOException; +import java.nio.ByteBuffer; + +class RawTransport extends Transport { + private volatile boolean mStopped; + + RawTransport(int associationId, ParcelFileDescriptor fd, Context context, Listener listener) { + super(associationId, fd, context, listener); + } + + @Override + public void start() { + new Thread(() -> { + try { + while (!mStopped) { + receiveMessage(); + } + } catch (IOException e) { + if (!mStopped) { + Slog.w(TAG, "Trouble during transport", e); + stop(); + } + } + }).start(); + } + + @Override + public void stop() { + mStopped = true; + + IoUtils.closeQuietly(mRemoteIn); + IoUtils.closeQuietly(mRemoteOut); + } + + @Override + protected void sendMessage(int message, int sequence, @NonNull byte[] data) + throws IOException { + if (DEBUG) { + Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message) + + " sequence " + sequence + " length " + data.length + + " to association " + mAssociationId); + } + + synchronized (mRemoteOut) { + final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH) + .putInt(message) + .putInt(sequence) + .putInt(data.length); + mRemoteOut.write(header.array()); + mRemoteOut.write(data); + mRemoteOut.flush(); + } + } + + private void receiveMessage() throws IOException { + final byte[] headerBytes = new byte[HEADER_LENGTH]; + Streams.readFully(mRemoteIn, headerBytes); + final ByteBuffer header = ByteBuffer.wrap(headerBytes); + final int message = header.getInt(); + final int sequence = header.getInt(); + final int length = header.getInt(); + final byte[] data = new byte[length]; + Streams.readFully(mRemoteIn, data); + + handleMessage(message, sequence, data); + } +} diff --git a/services/companion/java/com/android/server/companion/transport/SecureTransport.java b/services/companion/java/com/android/server/companion/transport/SecureTransport.java new file mode 100644 index 000000000000..4194130f7e84 --- /dev/null +++ b/services/companion/java/com/android/server/companion/transport/SecureTransport.java @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.companion.transport; + +import android.annotation.NonNull; +import android.content.Context; +import android.os.ParcelFileDescriptor; +import android.util.Slog; + +import com.android.server.companion.securechannel.SecureChannel; +import com.android.server.companion.transport.CompanionTransportManager.Listener; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Future; + +class SecureTransport extends Transport implements SecureChannel.Callback { + private final SecureChannel mSecureChannel; + + private volatile boolean mShouldProcessRequests = false; + + private final BlockingQueue<byte[]> mRequestQueue = new ArrayBlockingQueue<>(100); + + SecureTransport(int associationId, + ParcelFileDescriptor fd, + Context context, + Listener listener) { + super(associationId, fd, context, listener); + mSecureChannel = new SecureChannel(mRemoteIn, mRemoteOut, this, context); + } + + @Override + public void start() { + mSecureChannel.start(); + } + + @Override + public void stop() { + mSecureChannel.stop(); + mShouldProcessRequests = false; + } + + @Override + public Future<byte[]> requestForResponse(int message, byte[] data) { + // Check if channel is secured and start securing + if (!mShouldProcessRequests) { + Slog.d(TAG, "Establishing secure connection."); + try { + mSecureChannel.establishSecureConnection(); + } catch (Exception e) { + Slog.w(TAG, "Failed to initiate secure channel handshake.", e); + onError(e); + } + } + + return super.requestForResponse(message, data); + } + + @Override + protected void sendMessage(int message, int sequence, @NonNull byte[] data) + throws IOException { + if (DEBUG) { + Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message) + + " sequence " + sequence + " length " + data.length + + " to association " + mAssociationId); + } + + // Queue up a message to send + mRequestQueue.add(ByteBuffer.allocate(HEADER_LENGTH + data.length) + .putInt(message) + .putInt(sequence) + .putInt(data.length) + .put(data) + .array()); + } + + @Override + public void onSecureConnection() { + mShouldProcessRequests = true; + Slog.d(TAG, "Secure connection established."); + + // TODO: find a better way to handle incoming requests than a dedicated thread. + new Thread(() -> { + try { + while (mShouldProcessRequests) { + byte[] request = mRequestQueue.poll(); + if (request != null) { + mSecureChannel.sendSecureMessage(request); + } + } + } catch (IOException e) { + onError(e); + } + }).start(); + } + + @Override + public void onSecureMessageReceived(byte[] data) { + final ByteBuffer payload = ByteBuffer.wrap(data); + final int message = payload.getInt(); + final int sequence = payload.getInt(); + final int length = payload.getInt(); + final byte[] content = new byte[length]; + payload.get(content); + + try { + handleMessage(message, sequence, content); + } catch (IOException error) { + onError(error); + } + } + + @Override + public void onError(Throwable error) { + mShouldProcessRequests = false; + Slog.e(TAG, error.getMessage(), error); + } +} diff --git a/services/companion/java/com/android/server/companion/transport/Transport.java b/services/companion/java/com/android/server/companion/transport/Transport.java new file mode 100644 index 000000000000..923d4243a34c --- /dev/null +++ b/services/companion/java/com/android/server/companion/transport/Transport.java @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.companion.transport; + +import android.annotation.NonNull; +import android.content.Context; +import android.content.pm.PackageManager; +import android.os.Build; +import android.os.ParcelFileDescriptor; +import android.util.Slog; +import android.util.SparseArray; + +import com.android.internal.annotations.GuardedBy; +import com.android.server.companion.transport.CompanionTransportManager.Listener; + +import libcore.util.EmptyArray; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; + +abstract class Transport { + protected static final String TAG = "CDM_CompanionTransport"; + protected static final boolean DEBUG = Build.IS_DEBUGGABLE; + + static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN + static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES + + static final int MESSAGE_RESPONSE_SUCCESS = 0x33838567; // !SUC + static final int MESSAGE_RESPONSE_FAILURE = 0x33706573; // !FAI + + protected static final int HEADER_LENGTH = 12; + + protected final int mAssociationId; + protected final InputStream mRemoteIn; + protected final OutputStream mRemoteOut; + protected final Context mContext; + + private final Listener mListener; + + private static boolean isRequest(int message) { + return (message & 0xFF000000) == 0x63000000; + } + + private static boolean isResponse(int message) { + return (message & 0xFF000000) == 0x33000000; + } + + @GuardedBy("mPendingRequests") + protected final SparseArray<CompletableFuture<byte[]>> mPendingRequests = + new SparseArray<>(); + protected final AtomicInteger mNextSequence = new AtomicInteger(); + + Transport(int associationId, ParcelFileDescriptor fd, Context context, Listener listener) { + this.mAssociationId = associationId; + this.mRemoteIn = new ParcelFileDescriptor.AutoCloseInputStream(fd); + this.mRemoteOut = new ParcelFileDescriptor.AutoCloseOutputStream(fd); + this.mContext = context; + this.mListener = listener; + } + + public abstract void start(); + public abstract void stop(); + + public Future<byte[]> requestForResponse(int message, byte[] data) { + if (DEBUG) Slog.d(TAG, "Requesting for response"); + final int sequence = mNextSequence.incrementAndGet(); + final CompletableFuture<byte[]> pending = new CompletableFuture<>(); + synchronized (mPendingRequests) { + mPendingRequests.put(sequence, pending); + } + + try { + sendMessage(message, sequence, data); + } catch (IOException e) { + synchronized (mPendingRequests) { + mPendingRequests.remove(sequence); + } + pending.completeExceptionally(e); + } + + return pending; + } + + protected abstract void sendMessage(int message, int sequence, @NonNull byte[] data) + throws IOException; + + protected final void handleMessage(int message, int sequence, @NonNull byte[] data) + throws IOException { + if (DEBUG) { + Slog.d(TAG, "Received message 0x" + Integer.toHexString(message) + + " sequence " + sequence + " length " + data.length + + " from association " + mAssociationId); + } + + if (isRequest(message)) { + try { + processRequest(message, sequence, data); + } catch (IOException e) { + Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e); + } + } else if (isResponse(message)) { + processResponse(message, sequence, data); + } else { + Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message)); + } + } + + private void processRequest(int message, int sequence, byte[] data) + throws IOException { + switch (message) { + case MESSAGE_REQUEST_PING: { + sendMessage(MESSAGE_RESPONSE_SUCCESS, sequence, data); + break; + } + case MESSAGE_REQUEST_PERMISSION_RESTORE: { + if (!mContext.getPackageManager().hasSystemFeature(PackageManager.FEATURE_WATCH) + && !Build.isDebuggable()) { + Slog.w(TAG, "Restoring permissions only supported on watches"); + sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE); + break; + } + try { + mListener.onRequestPermissionRestore(data); + sendMessage(MESSAGE_RESPONSE_SUCCESS, sequence, EmptyArray.BYTE); + } catch (Exception e) { + Slog.w(TAG, "Failed to restore permissions"); + sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE); + } + break; + } + default: { + Slog.w(TAG, "Unknown request 0x" + Integer.toHexString(message)); + sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE); + break; + } + } + } + + private void processResponse(int message, int sequence, byte[] data) { + final CompletableFuture<byte[]> future; + synchronized (mPendingRequests) { + future = mPendingRequests.removeReturnOld(sequence); + } + if (future == null) { + Slog.w(TAG, "Ignoring unknown sequence " + sequence); + return; + } + + switch (message) { + case MESSAGE_RESPONSE_SUCCESS: { + future.complete(data); + break; + } + case MESSAGE_RESPONSE_FAILURE: { + future.completeExceptionally(new RuntimeException("Remote failure")); + break; + } + default: { + Slog.w(TAG, "Ignoring unknown response 0x" + Integer.toHexString(message)); + } + } + } +} |