diff options
| -rw-r--r-- | services/companion/java/com/android/server/companion/securechannel/SecureChannel.java | 99 | ||||
| -rw-r--r-- | services/companion/java/com/android/server/companion/transport/SecureTransport.java | 21 |
2 files changed, 99 insertions, 21 deletions
diff --git a/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java b/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java index 0457e9aa345d..5a3db4b18a1a 100644 --- a/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java +++ b/services/companion/java/com/android/server/companion/securechannel/SecureChannel.java @@ -53,8 +53,6 @@ public class SecureChannel { private static final int VERSION = 1; private static final int HEADER_LENGTH = 6; - private static final String HANDSHAKE_PROTOCOL = "AES_256_CBC-HMAC_SHA256"; - private final InputStream mInput; private final OutputStream mOutput; private final Callback mCallback; @@ -62,14 +60,16 @@ public class SecureChannel { private final AttestationVerifier mVerifier; private volatile boolean mStopped; - private boolean mInProgress; + private volatile boolean mInProgress; private Role mRole; + private byte[] mClientInit; private D2DHandshakeContext mHandshakeContext; private D2DConnectionContextV1 mConnectionContext; private String mAlias; private int mVerificationResult; + private boolean mPskVerified; /** @@ -202,8 +202,8 @@ public class SecureChannel { } try { - initiateHandshake(); mInProgress = true; + initiateHandshake(); } catch (BadHandleException e) { throw new SecureChannelException("Failed to initiate handshake protocol.", e); } @@ -329,12 +329,56 @@ public class SecureChannel { mRole = Role.Initiator; mHandshakeContext = D2DHandshakeContext.forInitiator(); + mClientInit = mHandshakeContext.getNextHandshakeMessage(); // Send Client Init if (DEBUG) { Slog.d(TAG, "Sending Ukey2 Client Init message"); } - sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage()); + sendMessage(MessageType.HANDSHAKE_INIT, constructHandshakeInitMessage(mClientInit)); + } + + // In an occasion where both participants try to initiate a handshake, resolve the conflict + // with a dice roll simulated by the message byte content comparison. + // The higher value wins! (a.k.a. gets to be the initiator) + private byte[] handleHandshakeCollision(byte[] handshakeInitMessage) + throws IOException, HandshakeException, BadHandleException, CryptoException { + + // First byte indicates message type; 0 = CLIENT INIT, 1 = SERVER INIT + ByteBuffer buffer = ByteBuffer.wrap(handshakeInitMessage); + boolean isClientInit = buffer.get() == 0; + byte[] handshakeMessage = new byte[buffer.remaining()]; + buffer.get(handshakeMessage); + + // If received message is Server Init or current role is Responder, then there was + // no collision. Return extracted handshake message. + if (mHandshakeContext == null || !isClientInit) { + return handshakeMessage; + } + + Slog.w(TAG, "Detected a Ukey2 handshake role collision. Negotiating a role."); + + // if received message is "larger" than the sent message, then reset the handshake context. + if (compareByteArray(mClientInit, handshakeMessage) < 0) { + Slog.d(TAG, "Assigned: Responder"); + mHandshakeContext = null; + return handshakeMessage; + } else { + Slog.d(TAG, "Assigned: Initiator; Discarding received Client Init"); + + // Wait for another init message after discarding the client init + ByteBuffer nextInitMessage = ByteBuffer.wrap(readMessage(MessageType.HANDSHAKE_INIT)); + + // Throw if this message is a Client Init again; 0 = CLIENT INIT, 1 = SERVER INIT + if (nextInitMessage.get() == 0) { + // This should never happen! + throw new HandshakeException("Failed to resolve Ukey2 handshake role collision."); + } + byte[] nextHandshakeMessage = new byte[nextInitMessage.remaining()]; + nextInitMessage.get(nextHandshakeMessage); + + return nextHandshakeMessage; + } } private void exchangeHandshake() @@ -345,8 +389,15 @@ public class SecureChannel { } // Waiting for message - byte[] handshakeMessage = readMessage(MessageType.HANDSHAKE_INIT); + byte[] handshakeInitMessage = readMessage(MessageType.HANDSHAKE_INIT); + + // Mark "in-progress" upon receiving the first message + mInProgress = true; + // Handle a potential collision where both devices tried to initiate a connection + byte[] handshakeMessage = handleHandshakeCollision(handshakeInitMessage); + + // Proceed with the rest of Ukey2 handshake if (mHandshakeContext == null) { // Server-side logic mRole = Role.Responder; mHandshakeContext = D2DHandshakeContext.forResponder(); @@ -361,7 +412,8 @@ public class SecureChannel { if (DEBUG) { Slog.d(TAG, "Sending Ukey2 Server Init message"); } - sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage()); + sendMessage(MessageType.HANDSHAKE_INIT, + constructHandshakeInitMessage(mHandshakeContext.getNextHandshakeMessage())); // Receive Client Finish if (DEBUG) { @@ -418,9 +470,9 @@ public class SecureChannel { ? Role.Responder : Role.Initiator, mPreSharedKey); - boolean authenticated = Arrays.equals(receivedAuthToken, expectedAuthToken); + mPskVerified = Arrays.equals(receivedAuthToken, expectedAuthToken); - if (!authenticated) { + if (!mPskVerified) { throw new SecureChannelException("Failed to verify the hash of pre-shared key."); } @@ -477,10 +529,21 @@ public class SecureChannel { } private boolean isSecured() { + // Is ukey-2 encrypted if (mConnectionContext == null) { return false; } - return mVerifier == null || mVerificationResult == RESULT_SUCCESS; + // Is authenticated + return mPskVerified || mVerificationResult == RESULT_SUCCESS; + } + + // First byte indicates message type; 0 = CLIENT INIT, 1 = SERVER INIT + // This information is needed to help resolve potential role collision. + private byte[] constructHandshakeInitMessage(byte[] message) { + return ByteBuffer.allocate(1 + message.length) + .put((byte) (Role.Initiator.equals(mRole) ? 0 : 1)) + .put(message) + .array(); } private byte[] constructToken(D2DHandshakeContext.Role role, byte[] authValue) @@ -494,6 +557,22 @@ public class SecureChannel { .array()); } + // Arbitrary comparator + private int compareByteArray(byte[] a, byte[] b) { + if (a == b) { + return 0; + } + if (a.length != b.length) { + return a.length - b.length; + } + for (int i = 0; i < a.length; i++) { + if (a[i] != b[i]) { + return a[i] - b[i]; + } + } + return 0; + } + private String generateAlias() { String alias; do { diff --git a/services/companion/java/com/android/server/companion/transport/SecureTransport.java b/services/companion/java/com/android/server/companion/transport/SecureTransport.java index 277bd88e21ca..2d856b9614cb 100644 --- a/services/companion/java/com/android/server/companion/transport/SecureTransport.java +++ b/services/companion/java/com/android/server/companion/transport/SecureTransport.java @@ -29,7 +29,6 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.Future; class SecureTransport extends Transport implements SecureChannel.Callback { private final SecureChannel mSecureChannel; @@ -70,7 +69,10 @@ class SecureTransport extends Transport implements SecureChannel.Callback { @Override protected void sendMessage(int message, int sequence, @NonNull byte[] data) throws IOException { - establishSecureConnection(); + // Check if channel is secured; otherwise start securing + if (!mShouldProcessRequests) { + establishSecureConnection(); + } if (DEBUG) { Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message) @@ -90,15 +92,12 @@ class SecureTransport extends Transport implements SecureChannel.Callback { } private void establishSecureConnection() { - // Check if channel is secured and start securing - if (!mShouldProcessRequests) { - Slog.d(TAG, "Establishing secure connection."); - try { - mSecureChannel.establishSecureConnection(); - } catch (Exception e) { - Slog.w(TAG, "Failed to initiate secure channel handshake.", e); - onError(e); - } + Slog.d(TAG, "Establishing secure connection."); + try { + mSecureChannel.establishSecureConnection(); + } catch (Exception e) { + Slog.w(TAG, "Failed to initiate secure channel handshake.", e); + onError(e); } } |