diff options
| author | 2023-02-21 21:33:49 +0000 | |
|---|---|---|
| committer | 2023-02-21 21:33:49 +0000 | |
| commit | f7fd412b15caa3f6a1ed6c2a5ca1c2272fddf83c (patch) | |
| tree | 5185d9f1586ad4d055817edcd1082060dfeb1c79 | |
| parent | 0df35457506228f9f8b61035b72c7234c985251b (diff) | |
| parent | c289e31a94a959c89e43d54fb5e227e5d99570a9 (diff) | |
Merge changes from topic "cdm-secure-channel-udc-dev" into udc-dev
* changes:
Introduce hidden API to disable secure channel for back-compatibility
Integrate secure channel into CDM
Implement secure channel
Add Ukey2 dependency
11 files changed, 1060 insertions, 79 deletions
diff --git a/core/java/android/companion/CompanionDeviceManager.java b/core/java/android/companion/CompanionDeviceManager.java index 5df2d5e1de35..de4f619392c1 100644 --- a/core/java/android/companion/CompanionDeviceManager.java +++ b/core/java/android/companion/CompanionDeviceManager.java @@ -1194,6 +1194,20 @@ public final class CompanionDeviceManager { } } + /** + * Enable or disable secure transport for testing. Defaults to enabled. + * + * @param enabled true to enable. false to disable. + * @hide + */ + public void enableSecureTransport(boolean enabled) { + try { + mService.enableSecureTransport(enabled); + } catch (RemoteException e) { + throw e.rethrowFromSystemServer(); + } + } + private boolean checkFeaturePresent() { boolean featurePresent = mService != null; if (!featurePresent && DEBUG) { diff --git a/core/java/android/companion/ICompanionDeviceManager.aidl b/core/java/android/companion/ICompanionDeviceManager.aidl index 010aa8f8a504..cb4baca73ba0 100644 --- a/core/java/android/companion/ICompanionDeviceManager.aidl +++ b/core/java/android/companion/ICompanionDeviceManager.aidl @@ -88,4 +88,6 @@ interface ICompanionDeviceManager { void enableSystemDataSync(int associationId, int flags); void disableSystemDataSync(int associationId, int flags); + + void enableSecureTransport(boolean enabled); } diff --git a/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java b/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java index d63384332400..2b4123af3885 100644 --- a/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java +++ b/core/tests/companiontests/src/android/companion/SystemDataTransportTest.java @@ -60,6 +60,7 @@ public class SystemDataTransportTest extends InstrumentationTestCase { mContext = getInstrumentation().getTargetContext(); mCdm = mContext.getSystemService(CompanionDeviceManager.class); mAssociationId = createAssociation(); + mCdm.enableSecureTransport(false); } @Override @@ -67,6 +68,7 @@ public class SystemDataTransportTest extends InstrumentationTestCase { super.tearDown(); mCdm.disassociate(mAssociationId); + mCdm.enableSecureTransport(true); } public void testPingHandRolled() { diff --git a/services/Android.bp b/services/Android.bp index f8097ec1bb92..6e6c55325e3d 100644 --- a/services/Android.bp +++ b/services/Android.bp @@ -195,6 +195,10 @@ java_library { "manifest_services.xml", ], + required: [ + "libukey2_jni_shared", + ], + // Uncomment to enable output of certain warnings (deprecated, unchecked) //javacflags: ["-Xlint"], } diff --git a/services/companion/Android.bp b/services/companion/Android.bp index cdeb2dcf87e9..a248d9e55a8a 100644 --- a/services/companion/Android.bp +++ b/services/companion/Android.bp @@ -24,4 +24,7 @@ java_library_static { "app-compat-annotations", "services.core", ], + static_libs: [ + "ukey2_jni", + ], } diff --git a/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java b/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java index 0f2ba35bd2ab..a35cae9dffda 100644 --- a/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java +++ b/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java @@ -726,6 +726,11 @@ public class CompanionDeviceManagerService extends SystemService { } @Override + public void enableSecureTransport(boolean enabled) { + mTransportManager.enableSecureTransport(enabled); + } + + @Override public void notifyDeviceAppeared(int associationId) { if (DEBUG) Log.i(TAG, "notifyDevice_Appeared() id=" + associationId); diff --git a/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java b/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java new file mode 100644 index 000000000000..adaee757b96a --- /dev/null +++ b/services/companion/java/com/android/server/companion/securechannel/AttestationVerifier.java @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.companion.securechannel; + +import static android.security.attestationverification.AttestationVerificationManager.PARAM_CHALLENGE; +import static android.security.attestationverification.AttestationVerificationManager.PROFILE_PEER_DEVICE; +import static android.security.attestationverification.AttestationVerificationManager.TYPE_CHALLENGE; + +import android.annotation.NonNull; +import android.content.Context; +import android.os.Bundle; +import android.security.attestationverification.AttestationProfile; +import android.security.attestationverification.AttestationVerificationManager; +import android.security.attestationverification.VerificationToken; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; + +/** + * Helper class to perform attestation verification synchronously. + */ +class AttestationVerifier { + private static final long ATTESTATION_VERIFICATION_TIMEOUT_SECONDS = 10; // 10 seconds + private static final String PARAM_OWNED_BY_SYSTEM = "android.key_owned_by_system"; + + private final Context mContext; + + AttestationVerifier(Context context) { + this.mContext = context; + } + + /** + * Synchronously verify remote attestation as a suitable peer device on current thread. + * + * The peer device must be owned by the Android system and be protected with appropriate + * public key that this device can verify as attestation challenge. + * + * @param remoteAttestation the full certificate chain containing attestation extension. + * @param attestationChallenge attestation challenge for authentication. + * @return true if attestation is successfully verified; false otherwise. + */ + @NonNull + public int verifyAttestation( + @NonNull byte[] remoteAttestation, + @NonNull byte[] attestationChallenge + ) throws SecureChannelException { + Bundle requirements = new Bundle(); + requirements.putByteArray(PARAM_CHALLENGE, attestationChallenge); + requirements.putBoolean(PARAM_OWNED_BY_SYSTEM, true); // Custom parameter for CDM + + // Synchronously execute attestation verification. + AtomicInteger verificationResult = new AtomicInteger(0); + CountDownLatch verificationFinished = new CountDownLatch(1); + BiConsumer<Integer, VerificationToken> onVerificationResult = (result, token) -> { + verificationResult.set(result); + verificationFinished.countDown(); + }; + + mContext.getSystemService(AttestationVerificationManager.class).verifyAttestation( + new AttestationProfile(PROFILE_PEER_DEVICE), + /* localBindingType */ TYPE_CHALLENGE, + requirements, + remoteAttestation, + Runnable::run, + onVerificationResult + ); + + boolean finished; + try { + finished = verificationFinished.await( + ATTESTATION_VERIFICATION_TIMEOUT_SECONDS, + TimeUnit.SECONDS + ); + } catch (InterruptedException e) { + throw new SecureChannelException("Attestation verification was interrupted", e); + } + + if (!finished) { + throw new SecureChannelException("Attestation verification timed out."); + } + + return verificationResult.get(); + } +} diff --git a/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java b/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java new file mode 100644 index 000000000000..18ebec4b6fd3 --- /dev/null +++ b/services/companion/java/com/android/server/companion/securechannel/KeyStoreUtils.java @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.companion.securechannel; + +import static android.security.keystore.KeyProperties.DIGEST_SHA256; +import static android.security.keystore.KeyProperties.KEY_ALGORITHM_EC; +import static android.security.keystore.KeyProperties.PURPOSE_SIGN; +import static android.security.keystore.KeyProperties.PURPOSE_VERIFY; + +import android.security.keystore.KeyGenParameterSpec; +import android.security.keystore2.AndroidKeyStoreSpi; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyPairGenerator; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.cert.Certificate; + +/** + * Utility class to help generate, store, and access key-pair for the secure channel. Uses + * Android Keystore. + */ +final class KeyStoreUtils { + private static final String TAG = "CDM_SecureChannelKeyStore"; + private static final String ANDROID_KEYSTORE = AndroidKeyStoreSpi.NAME; + + private KeyStoreUtils() {} + + /** + * Load Android keystore to be used by the secure channel. + * + * @return loaded keystore instance + */ + static KeyStore loadKeyStore() throws GeneralSecurityException { + KeyStore androidKeyStore = KeyStore.getInstance(ANDROID_KEYSTORE); + + try { + androidKeyStore.load(null); + } catch (IOException e) { + // Should not happen + throw new KeyStoreException("Failed to load Android Keystore.", e); + } + + return androidKeyStore; + } + + /** + * Fetch the certificate chain encoded as byte array in the form of concatenated + * X509 certificates. + * + * @param alias unique alias for the key-pair entry + * @return a single byte-array containing the entire certificate chain + */ + static byte[] getEncodedCertificateChain(String alias) throws GeneralSecurityException { + KeyStore ks = loadKeyStore(); + + Certificate[] certificateChain = ks.getCertificateChain(alias); + + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + for (Certificate certificate : certificateChain) { + buffer.writeBytes(certificate.getEncoded()); + } + return buffer.toByteArray(); + } + + /** + * Generate a new attestation key-pair. + * + * @param alias unique alias for the key-pair entry + * @param attestationChallenge challenge value to check against for authentication + */ + static void generateAttestationKeyPair(String alias, byte[] attestationChallenge) + throws GeneralSecurityException { + KeyGenParameterSpec parameterSpec = + new KeyGenParameterSpec.Builder(alias, PURPOSE_SIGN | PURPOSE_VERIFY) + .setAttestationChallenge(attestationChallenge) + .setDigests(DIGEST_SHA256) + .build(); + + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance( + /* algorithm */ KEY_ALGORITHM_EC, + /* provider */ ANDROID_KEYSTORE); + keyPairGenerator.initialize(parameterSpec); + keyPairGenerator.generateKeyPair(); + } + + /** + * Check if alias exists. + * + * @param alias unique alias for the key-pair entry + * @return true if given alias already exists in the keystore + */ + static boolean aliasExists(String alias) { + try { + KeyStore ks = loadKeyStore(); + return ks.containsAlias(alias); + } catch (GeneralSecurityException e) { + return false; + } + + } + + static void cleanUp(String alias) { + try { + KeyStore ks = loadKeyStore(); + + if (ks.containsAlias(alias)) { + ks.deleteEntry(alias); + } + } catch (Exception ignored) { + // Do nothing; + } + } +} diff --git a/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java b/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java new file mode 100644 index 000000000000..13dba84487e3 --- /dev/null +++ b/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java @@ -0,0 +1,543 @@ +/* + * Copyright (C) 2022 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.companion.securechannel; + +import static android.security.attestationverification.AttestationVerificationManager.RESULT_SUCCESS; + +import android.annotation.NonNull; +import android.content.Context; +import android.os.Build; +import android.util.Slog; + +import com.google.security.cryptauth.lib.securegcm.BadHandleException; +import com.google.security.cryptauth.lib.securegcm.CryptoException; +import com.google.security.cryptauth.lib.securegcm.D2DConnectionContextV1; +import com.google.security.cryptauth.lib.securegcm.D2DHandshakeContext; +import com.google.security.cryptauth.lib.securegcm.D2DHandshakeContext.Role; +import com.google.security.cryptauth.lib.securegcm.DefaultUkey2Logger; +import com.google.security.cryptauth.lib.securegcm.HandshakeException; + +import libcore.io.IoUtils; +import libcore.io.Streams; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.security.MessageDigest; +import java.util.Arrays; +import java.util.UUID; + +/** + * Data stream channel that establishes secure connection between two peer devices. + */ +public class SecureChannel { + private static final String TAG = "CDM_SecureChannel"; + private static final boolean DEBUG = Build.IS_DEBUGGABLE; + + private static final int VERSION = 1; + private static final int HEADER_LENGTH = 6; + + private static final String HANDSHAKE_PROTOCOL = "AES_256_CBC-HMAC_SHA256"; + + private final InputStream mInput; + private final OutputStream mOutput; + private final Callback mCallback; + private final byte[] mPreSharedKey; + private final AttestationVerifier mVerifier; + + private volatile boolean mStopped; + private boolean mInProgress; + + private Role mRole; + private D2DHandshakeContext mHandshakeContext; + private D2DConnectionContextV1 mConnectionContext; + + private String mAlias; + private int mVerificationResult; + + + /** + * Create a new secure channel object. This secure channel allows secure messages to be + * exchanged with unattested devices. The pre-shared key must have been distributed to both + * participants of the channel in a secure way previously. + * + * @param in input stream from which data is received + * @param out output stream from which data is sent out + * @param callback subscription to received messages from the channel + * @param preSharedKey pre-shared key to authenticate unattested participant + */ + public SecureChannel( + @NonNull final InputStream in, + @NonNull final OutputStream out, + @NonNull Callback callback, + @NonNull byte[] preSharedKey + ) { + this(in, out, callback, preSharedKey, null); + } + + /** + * Create a new secure channel object. This secure channel allows secure messages to be + * exchanged with Android devices that were authenticated and verified with an attestation key. + * + * @param in input stream from which data is received + * @param out output stream from which data is sent out + * @param callback subscription to received messages from the channel + * @param context context for fetching the Attestation Verifier Framework system service + */ + public SecureChannel( + @NonNull final InputStream in, + @NonNull final OutputStream out, + @NonNull Callback callback, + @NonNull Context context + ) { + this(in, out, callback, null, new AttestationVerifier(context)); + } + + private SecureChannel( + final InputStream in, + final OutputStream out, + Callback callback, + byte[] preSharedKey, + AttestationVerifier verifier + ) { + this.mInput = in; + this.mOutput = out; + this.mCallback = callback; + this.mPreSharedKey = preSharedKey; + this.mVerifier = verifier; + } + + /** + * Start listening for incoming messages. + */ + public void start() { + new Thread(() -> { + try { + // 1. Wait for the next handshake message and process it. + exchangeHandshake(); + + // 2. Authenticate remote actor via attestation or pre-shared key. + exchangeAuthentication(); + + // 3. Notify secure channel is ready. + mInProgress = false; + mCallback.onSecureConnection(); + + // Listen for secure messages. + while (!mStopped) { + receiveSecureMessage(); + } + } catch (Exception e) { + if (mStopped) { + return; + } + // TODO: Handle different types errors. + + Slog.e(TAG, "Secure channel encountered an error.", e); + stop(); + mCallback.onError(e); + } + }).start(); + } + + /** + * Stop listening to incoming messages and close the channel. + */ + public void stop() { + if (DEBUG) { + Slog.d(TAG, "Stopping secure channel."); + } + mStopped = true; + mInProgress = false; + + IoUtils.closeQuietly(mInput); + IoUtils.closeQuietly(mOutput); + KeyStoreUtils.cleanUp(mAlias); + } + + /** + * Start exchanging handshakes to create a secure layer asynchronously. When the handshake is + * completed successfully, then the {@link Callback#onSecureConnection()} will trigger. Any + * error that occurs during the handshake will be passed by {@link Callback#onError(Throwable)}. + * + * This method must only be called from one of the two participants. + */ + public void establishSecureConnection() throws IOException, SecureChannelException { + if (isSecured()) { + Slog.d(TAG, "Channel is already secure."); + return; + } + if (mInProgress) { + Slog.w(TAG, "Channel has already started establishing secure connection."); + return; + } + + try { + initiateHandshake(); + mInProgress = true; + } catch (BadHandleException e) { + throw new SecureChannelException("Failed to initiate handshake protocol.", e); + } + } + + /** + * Send an encrypted, authenticated message via this channel. + * + * @param data data to be sent to the other side. + * @throws IOException if the output stream fails to write given data. + */ + public void sendSecureMessage(byte[] data) throws IOException { + if (!isSecured()) { + Slog.d(TAG, "Cannot send a message without a secure connection."); + throw new IllegalStateException("Channel is not secured yet."); + } + + // Encrypt constructed message + try { + sendMessage(MessageType.SECURE_MESSAGE, data); + } catch (BadHandleException e) { + throw new SecureChannelException("Failed to encrypt data.", e); + } + } + + private void receiveSecureMessage() throws IOException, CryptoException { + // Check if channel is secured. Trigger error callback. Let user handle it. + if (!isSecured()) { + Slog.d(TAG, "Received a message without a secure connection. " + + "Message will be ignored."); + mCallback.onError(new IllegalStateException("Connection is not secure.")); + return; + } + + try { + byte[] receivedMessage = readMessage(MessageType.SECURE_MESSAGE); + mCallback.onSecureMessageReceived(receivedMessage); + } catch (SecureChannelException e) { + Slog.w(TAG, "Ignoring received message.", e); + } + } + + private byte[] readMessage(MessageType expected) + throws IOException, SecureChannelException, CryptoException { + if (DEBUG) { + if (isSecured()) { + Slog.d(TAG, "Waiting to receive next secure message."); + } else { + Slog.d(TAG, "Waiting to receive next message."); + } + } + + // TODO: Handle message timeout + + // Header is _not_ encrypted, but will be covered by MAC + final byte[] headerBytes = new byte[HEADER_LENGTH]; + Streams.readFully(mInput, headerBytes); + final ByteBuffer header = ByteBuffer.wrap(headerBytes); + final int version = header.getInt(); + final short type = header.getShort(); + + if (version != VERSION) { + Streams.skipByReading(mInput, Long.MAX_VALUE); + throw new SecureChannelException("Secure channel version mismatch. " + + "Currently on version " + VERSION + ". Skipping rest of data."); + } + + if (type != expected.mValue) { + Streams.skipByReading(mInput, Long.MAX_VALUE); + throw new SecureChannelException("Unexpected message type. Expected " + expected.name() + + "; Found " + MessageType.from(type).name() + ". Skipping rest of data."); + } + + // Length of attached data is prepended as plaintext + final byte[] lengthBytes = new byte[4]; + Streams.readFully(mInput, lengthBytes); + final int length = ByteBuffer.wrap(lengthBytes).getInt(); + + // Read data based on the length + final byte[] data; + try { + data = new byte[length]; + } catch (OutOfMemoryError error) { + throw new SecureChannelException("Payload is too large.", error); + } + + Streams.readFully(mInput, data); + if (!MessageType.shouldEncrypt(expected)) { + return data; + } + + return mConnectionContext.decodeMessageFromPeer(data, headerBytes); + } + + private void sendMessage(MessageType messageType, byte[] payload) + throws IOException, BadHandleException { + synchronized (mOutput) { + byte[] header = ByteBuffer.allocate(HEADER_LENGTH) + .putInt(VERSION) + .putShort(messageType.mValue) + .array(); + byte[] data = MessageType.shouldEncrypt(messageType) + ? mConnectionContext.encodeMessageToPeer(payload, header) + : payload; + mOutput.write(header); + mOutput.write(ByteBuffer.allocate(4) + .putInt(data.length) + .array()); + mOutput.write(data); + mOutput.flush(); + } + } + + private void initiateHandshake() throws IOException, BadHandleException { + if (mConnectionContext != null) { + Slog.d(TAG, "Ukey2 handshake is already completed."); + return; + } + + mRole = Role.Initiator; + mHandshakeContext = D2DHandshakeContext.forInitiator(DefaultUkey2Logger.INSTANCE); + + // Send Client Init + if (DEBUG) { + Slog.d(TAG, "Sending Ukey2 Client Init message"); + } + sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage()); + } + + private void exchangeHandshake() + throws IOException, HandshakeException, BadHandleException, CryptoException { + if (mConnectionContext != null) { + Slog.d(TAG, "Ukey2 handshake is already completed."); + return; + } + + // Waiting for message + byte[] handshakeMessage = readMessage(MessageType.HANDSHAKE_INIT); + + if (mHandshakeContext == null) { // Server-side logic + mRole = Role.Responder; + mHandshakeContext = D2DHandshakeContext.forResponder(DefaultUkey2Logger.INSTANCE); + + // Receive Client Init + if (DEBUG) { + Slog.d(TAG, "Receiving Ukey2 Client Init message"); + } + mHandshakeContext.parseHandshakeMessage(handshakeMessage); + + // Send Server Init + if (DEBUG) { + Slog.d(TAG, "Sending Ukey2 Server Init message"); + } + sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage()); + + // Receive Client Finish + if (DEBUG) { + Slog.d(TAG, "Receiving Ukey2 Client Finish message"); + } + mHandshakeContext.parseHandshakeMessage(readMessage(MessageType.HANDSHAKE_FINISH)); + } else { // Client-side logic + + // Receive Server Init + if (DEBUG) { + Slog.d(TAG, "Receiving Ukey2 Server Init message"); + } + mHandshakeContext.parseHandshakeMessage(handshakeMessage); + + // Send Client Finish + if (DEBUG) { + Slog.d(TAG, "Sending Ukey2 Client Finish message"); + } + sendMessage(MessageType.HANDSHAKE_FINISH, mHandshakeContext.getNextHandshakeMessage()); + } + + // Convert secrets to connection context + if (mHandshakeContext.isHandshakeComplete()) { + if (DEBUG) { + Slog.d(TAG, "Ukey2 Handshake completed successfully"); + } + mConnectionContext = mHandshakeContext.toConnectionContext(); + } else { + Slog.e(TAG, "Failed to complete Ukey2 Handshake"); + throw new IllegalStateException("Ukey2 Handshake did not complete as expected."); + } + } + + private void exchangeAuthentication() + throws IOException, GeneralSecurityException, BadHandleException, CryptoException { + if (mVerifier == null) { + exchangePreSharedKey(); + } else { + exchangeAttestation(); + } + } + + private void exchangePreSharedKey() + throws IOException, GeneralSecurityException, BadHandleException, CryptoException { + + // Exchange hashed pre-shared keys + if (DEBUG) { + Slog.d(TAG, "Exchanging pre-shared keys."); + } + sendMessage(MessageType.PRE_SHARED_KEY, constructToken(mRole, mPreSharedKey)); + byte[] receivedAuthToken = readMessage(MessageType.PRE_SHARED_KEY); + byte[] expectedAuthToken = constructToken(mRole == Role.Initiator + ? Role.Responder + : Role.Initiator, + mPreSharedKey); + boolean authenticated = Arrays.equals(receivedAuthToken, expectedAuthToken); + + if (!authenticated) { + throw new SecureChannelException("Failed to verify the hash of pre-shared key."); + } + + if (DEBUG) { + Slog.d(TAG, "The pre-shared key was successfully authenticated."); + } + } + + private void exchangeAttestation() + throws IOException, GeneralSecurityException, BadHandleException, CryptoException { + if (mVerificationResult == RESULT_SUCCESS) { + Slog.d(TAG, "Remote attestation was already verified."); + return; + } + + // Send local attestation + if (DEBUG) { + Slog.d(TAG, "Exchanging device attestation."); + } + if (mAlias == null) { + mAlias = generateAlias(); + } + byte[] localChallenge = constructToken(mRole, mConnectionContext.getSessionUnique()); + KeyStoreUtils.generateAttestationKeyPair(mAlias, localChallenge); + byte[] localAttestation = KeyStoreUtils.getEncodedCertificateChain(mAlias); + sendMessage(MessageType.ATTESTATION, localAttestation); + byte[] remoteAttestation = readMessage(MessageType.ATTESTATION); + + // Verifying remote attestation with public key local binding param + byte[] expectedChallenge = constructToken(mRole == Role.Initiator + ? Role.Responder + : Role.Initiator, + mConnectionContext.getSessionUnique()); + mVerificationResult = mVerifier.verifyAttestation(remoteAttestation, expectedChallenge); + + // Exchange attestation verification result and finish + byte[] verificationResult = ByteBuffer.allocate(4) + .putInt(mVerificationResult) + .array(); + sendMessage(MessageType.AVF_RESULT, verificationResult); + byte[] remoteVerificationResult = readMessage(MessageType.AVF_RESULT); + + if (ByteBuffer.wrap(remoteVerificationResult).getInt() != RESULT_SUCCESS) { + throw new SecureChannelException("Remote device failed to verify local attestation."); + } + + if (mVerificationResult != RESULT_SUCCESS) { + throw new SecureChannelException("Failed to verify remote attestation."); + } + + if (DEBUG) { + Slog.d(TAG, "Remote attestation was successfully verified."); + } + } + + private boolean isSecured() { + if (mConnectionContext == null) { + return false; + } + return mVerifier == null || mVerificationResult == RESULT_SUCCESS; + } + + private byte[] constructToken(D2DHandshakeContext.Role role, byte[] authValue) + throws GeneralSecurityException { + MessageDigest hash = MessageDigest.getInstance("SHA-256"); + byte[] roleUtf8 = role.name().getBytes(StandardCharsets.UTF_8); + int tokenLength = roleUtf8.length + authValue.length; + return hash.digest(ByteBuffer.allocate(tokenLength) + .put(roleUtf8) + .put(authValue) + .array()); + } + + private String generateAlias() { + String alias; + do { + alias = "secure-channel-" + UUID.randomUUID(); + } while (KeyStoreUtils.aliasExists(alias)); + return alias; + } + + private enum MessageType { + HANDSHAKE_INIT(0x4849), // HI + HANDSHAKE_FINISH(0x4846), // HF + PRE_SHARED_KEY(0x504b), // PK + ATTESTATION(0x4154), // AT + AVF_RESULT(0x5652), // VR + SECURE_MESSAGE(0x534d), // SM + UNKNOWN(0); // X + + private final short mValue; + + MessageType(int value) { + this.mValue = (short) value; + } + + static MessageType from(short value) { + for (MessageType messageType : values()) { + if (value == messageType.mValue) { + return messageType; + } + } + return UNKNOWN; + } + + // Encrypt every message besides Ukey2 handshake messages + private static boolean shouldEncrypt(MessageType type) { + return type != HANDSHAKE_INIT && type != HANDSHAKE_FINISH; + } + } + + /** + * Callback that passes securely received message to the subscribed user. + */ + public interface Callback { + /** + * Triggered after {@link SecureChannel#establishSecureConnection()} finishes exchanging + * every required handshakes to fully establish a secure connection. + */ + void onSecureConnection(); + + /** + * Callback that passes securely received and decrypted data to the subscribed user. + * + * @param data securely received plaintext data. + */ + void onSecureMessageReceived(byte[] data); + + /** + * Callback that passes error that occurred during handshakes or while listening to + * messages in the secure channel. + * + * @param error + */ + void onError(Throwable error); + } +} diff --git a/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java b/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java new file mode 100644 index 000000000000..68db97e35261 --- /dev/null +++ b/services/companion/java/com/android/server/companion/securechannel/SecureChannelException.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.companion.securechannel; + +/** + * Catch-all exception for any error in the secure channel. + */ +public class SecureChannelException extends RuntimeException { + /** + * + * @param message + */ + public SecureChannelException(String message) { + super(message); + } + + public SecureChannelException(String message, Throwable t) { + super(message, t); + } +} diff --git a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java index 6db99a0d0b73..494c5a6c0779 100644 --- a/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java +++ b/services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java @@ -34,6 +34,7 @@ import android.util.SparseArray; import com.android.internal.annotations.GuardedBy; import com.android.server.LocalServices; +import com.android.server.companion.securechannel.SecureChannel; import libcore.io.IoUtils; import libcore.io.Streams; @@ -43,6 +44,8 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; @@ -54,8 +57,6 @@ public class CompanionTransportManager { private static final boolean DEBUG = true; private static final int HEADER_LENGTH = 12; - // TODO: refactor message processing to use streams to remove this limit - private static final int MAX_PAYLOAD_LENGTH = 1_000_000; private static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN private static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES @@ -63,6 +64,8 @@ public class CompanionTransportManager { private static final int MESSAGE_RESPONSE_SUCCESS = 0x33838567; // !SUC private static final int MESSAGE_RESPONSE_FAILURE = 0x33706573; // !FAI + private boolean mSecureTransportEnabled = true; + private static boolean isRequest(int message) { return (message & 0xFF000000) == 0x63000000; } @@ -122,7 +125,13 @@ public class CompanionTransportManager { detachSystemDataTransport(packageName, userId, associationId); } - final Transport transport = new Transport(associationId, fd); + final Transport transport; + if (isSecureTransportEnabled(associationId)) { + transport = new SecureTransport(associationId, fd); + } else { + transport = new RawTransport(associationId, fd); + } + transport.start(); mTransports.put(associationId, transport); } @@ -142,61 +151,65 @@ public class CompanionTransportManager { public Future<?> requestPermissionRestore(int associationId, byte[] data) { synchronized (mTransports) { final Transport transport = mTransports.get(associationId); - if (transport != null) { - return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data); - } else { + if (transport == null) { return CompletableFuture.failedFuture(new IOException("Missing transport")); } + + return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data); } } - private class Transport { - private final int mAssociationId; - - private final InputStream mRemoteIn; - private final OutputStream mRemoteOut; + /** + * @hide + */ + public void enableSecureTransport(boolean enabled) { + this.mSecureTransportEnabled = enabled; + } - private final AtomicInteger mNextSequence = new AtomicInteger(); + private boolean isSecureTransportEnabled(int associationId) { + boolean enabled = !Build.IS_DEBUGGABLE || mSecureTransportEnabled; - @GuardedBy("mPendingRequests") - private final SparseArray<CompletableFuture<byte[]>> mPendingRequests = new SparseArray<>(); + // TODO: version comparison logic + return enabled; + } - private volatile boolean mStopped; + // TODO: Make Transport inner classes into standalone classes. + private abstract class Transport { + protected final int mAssociationId; + protected final InputStream mRemoteIn; + protected final OutputStream mRemoteOut; - public Transport(int associationId, ParcelFileDescriptor fd) { - mAssociationId = associationId; - mRemoteIn = new ParcelFileDescriptor.AutoCloseInputStream(fd); - mRemoteOut = new ParcelFileDescriptor.AutoCloseOutputStream(fd); + @GuardedBy("mPendingRequests") + protected final SparseArray<CompletableFuture<byte[]>> mPendingRequests = + new SparseArray<>(); + protected final AtomicInteger mNextSequence = new AtomicInteger(); + + Transport(int associationId, ParcelFileDescriptor fd) { + this(associationId, + new ParcelFileDescriptor.AutoCloseInputStream(fd), + new ParcelFileDescriptor.AutoCloseOutputStream(fd)); } - public void start() { - new Thread(() -> { - try { - while (!mStopped) { - receiveMessage(); - } - } catch (IOException e) { - if (!mStopped) { - Slog.w(TAG, "Trouble during transport", e); - stop(); - } - } - }).start(); + Transport(int associationId, InputStream in, OutputStream out) { + this.mAssociationId = associationId; + this.mRemoteIn = in; + this.mRemoteOut = out; } - public void stop() { - mStopped = true; + public abstract void start(); + public abstract void stop(); - IoUtils.closeQuietly(mRemoteIn); - IoUtils.closeQuietly(mRemoteOut); - } + protected abstract void sendMessage(int message, int sequence, @NonNull byte[] data) + throws IOException; public Future<byte[]> requestForResponse(int message, byte[] data) { + if (DEBUG) Slog.d(TAG, "Requesting for response"); final int sequence = mNextSequence.incrementAndGet(); final CompletableFuture<byte[]> pending = new CompletableFuture<>(); synchronized (mPendingRequests) { mPendingRequests.put(sequence, pending); } + try { sendMessage(message, sequence, data); } catch (IOException e) { @@ -205,58 +218,24 @@ public class CompanionTransportManager { } pending.completeExceptionally(e); } + return pending; } - private void sendMessage(int message, int sequence, @NonNull byte[] data) + protected final void handleMessage(int message, int sequence, @NonNull byte[] data) throws IOException { if (DEBUG) { - Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message) - + " sequence " + sequence + " length " + data.length - + " to association " + mAssociationId); - } - - synchronized (mRemoteOut) { - final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH) - .putInt(message) - .putInt(sequence) - .putInt(data.length); - mRemoteOut.write(header.array()); - mRemoteOut.write(data); - mRemoteOut.flush(); - } - } - - private void receiveMessage() throws IOException { - if (DEBUG) { - Slog.d(TAG, "Waiting for next message..."); - } - - final byte[] headerBytes = new byte[HEADER_LENGTH]; - Streams.readFully(mRemoteIn, headerBytes); - final ByteBuffer header = ByteBuffer.wrap(headerBytes); - final int message = header.getInt(); - final int sequence = header.getInt(); - final int length = header.getInt(); - - if (DEBUG) { Slog.d(TAG, "Received message 0x" + Integer.toHexString(message) - + " sequence " + sequence + " length " + length + + " sequence " + sequence + " length " + data.length + " from association " + mAssociationId); } - if (length > MAX_PAYLOAD_LENGTH) { - Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message) - + " sequence " + sequence + " length " + length - + " from association " + mAssociationId + " beyond maximum length"); - Streams.skipByReading(mRemoteIn, length); - return; - } - - final byte[] data = new byte[length]; - Streams.readFully(mRemoteIn, data); if (isRequest(message)) { - processRequest(message, sequence, data); + try { + processRequest(message, sequence, data); + } catch (IOException e) { + Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e); + } } else if (isResponse(message)) { processResponse(message, sequence, data); } else { @@ -320,4 +299,169 @@ public class CompanionTransportManager { } } } + + private class RawTransport extends Transport { + private volatile boolean mStopped; + + RawTransport(int associationId, ParcelFileDescriptor fd) { + super(associationId, fd); + } + + @Override + public void start() { + new Thread(() -> { + try { + while (!mStopped) { + receiveMessage(); + } + } catch (IOException e) { + if (!mStopped) { + Slog.w(TAG, "Trouble during transport", e); + stop(); + } + } + }).start(); + } + + @Override + public void stop() { + mStopped = true; + + IoUtils.closeQuietly(mRemoteIn); + IoUtils.closeQuietly(mRemoteOut); + } + + @Override + protected void sendMessage(int message, int sequence, @NonNull byte[] data) + throws IOException { + if (DEBUG) { + Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message) + + " sequence " + sequence + " length " + data.length + + " to association " + mAssociationId); + } + + synchronized (mRemoteOut) { + final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH) + .putInt(message) + .putInt(sequence) + .putInt(data.length); + mRemoteOut.write(header.array()); + mRemoteOut.write(data); + mRemoteOut.flush(); + } + } + + private void receiveMessage() throws IOException { + final byte[] headerBytes = new byte[HEADER_LENGTH]; + Streams.readFully(mRemoteIn, headerBytes); + final ByteBuffer header = ByteBuffer.wrap(headerBytes); + final int message = header.getInt(); + final int sequence = header.getInt(); + final int length = header.getInt(); + final byte[] data = new byte[length]; + Streams.readFully(mRemoteIn, data); + + handleMessage(message, sequence, data); + } + } + + private class SecureTransport extends Transport implements SecureChannel.Callback { + private final SecureChannel mSecureChannel; + + private volatile boolean mShouldProcessRequests = false; + + private final BlockingQueue<byte[]> mRequestQueue = new ArrayBlockingQueue<>(100); + + SecureTransport(int associationId, ParcelFileDescriptor fd) { + super(associationId, fd); + mSecureChannel = new SecureChannel(mRemoteIn, mRemoteOut, this, mContext); + } + + @Override + public void start() { + mSecureChannel.start(); + } + + @Override + public void stop() { + mSecureChannel.stop(); + mShouldProcessRequests = false; + } + + @Override + public Future<byte[]> requestForResponse(int message, byte[] data) { + // Check if channel is secured and start securing + if (!mShouldProcessRequests) { + Slog.d(TAG, "Establishing secure connection."); + try { + mSecureChannel.establishSecureConnection(); + } catch (Exception e) { + Slog.w(TAG, "Failed to initiate secure channel handshake.", e); + onError(e); + } + } + + return super.requestForResponse(message, data); + } + + @Override + protected void sendMessage(int message, int sequence, @NonNull byte[] data) + throws IOException { + if (DEBUG) { + Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message) + + " sequence " + sequence + " length " + data.length + + " to association " + mAssociationId); + } + + // Queue up a message to send + mRequestQueue.add(ByteBuffer.allocate(HEADER_LENGTH + data.length) + .putInt(message) + .putInt(sequence) + .putInt(data.length) + .put(data) + .array()); + } + + @Override + public void onSecureConnection() { + mShouldProcessRequests = true; + Slog.d(TAG, "Secure connection established."); + + // TODO: find a better way to handle incoming requests than a dedicated thread. + new Thread(() -> { + try { + while (mShouldProcessRequests) { + byte[] request = mRequestQueue.poll(); + if (request != null) { + mSecureChannel.sendSecureMessage(request); + } + } + } catch (IOException e) { + onError(e); + } + }).start(); + } + + @Override + public void onSecureMessageReceived(byte[] data) { + final ByteBuffer payload = ByteBuffer.wrap(data); + final int message = payload.getInt(); + final int sequence = payload.getInt(); + final int length = payload.getInt(); + final byte[] content = new byte[length]; + payload.get(content); + + try { + handleMessage(message, sequence, content); + } catch (IOException error) { + onError(error); + } + } + + @Override + public void onError(Throwable error) { + mShouldProcessRequests = false; + Slog.e(TAG, error.getMessage(), error); + } + } } |