From 100d7290264338c6536739abd59879aaaa812537 Mon Sep 17 00:00:00 2001 From: Kenny Root Date: Tue, 25 Jun 2013 12:00:34 -0700 Subject: Add ALPN support to SSL socket factory This adds the ability to use Application-Layer Protocol Negotiation (ALPN) through the SSLCertificateSocketFactory. ALPN is essentially like Next Protocol Negotiation (NPN) but negotiation is done in the clear. This allows the use of other protocols on the same port (e.g., SPDY instead of HTTP on port 80). Change-Id: Ie62926b455e252c4c98670bbbecc1eb5c6f13990 --- .../android/net/SSLCertificateSocketFactory.java | 56 +++++++++++++++++++--- core/tests/coretests/src/android/net/SSLTest.java | 16 +++---- 2 files changed, 58 insertions(+), 14 deletions(-) diff --git a/core/java/android/net/SSLCertificateSocketFactory.java b/core/java/android/net/SSLCertificateSocketFactory.java index 37f04d308773..31c8edb6ebb5 100644 --- a/core/java/android/net/SSLCertificateSocketFactory.java +++ b/core/java/android/net/SSLCertificateSocketFactory.java @@ -89,6 +89,7 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { private TrustManager[] mTrustManagers = null; private KeyManager[] mKeyManagers = null; private byte[] mNpnProtocols = null; + private byte[] mAlpnProtocols = null; private PrivateKey mChannelIdPrivateKey = null; private final int mHandshakeTimeoutMillis; @@ -268,19 +269,42 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { * must be non-empty and of length less than 256. */ public void setNpnProtocols(byte[][] npnProtocols) { - this.mNpnProtocols = toNpnProtocolsList(npnProtocols); + this.mNpnProtocols = toLengthPrefixedList(npnProtocols); + } + + /** + * Sets the + * + * Application Layer Protocol Negotiation (ALPN) protocols that this peer + * is interested in. + * + *

For servers this is the sequence of protocols to advertise as + * supported, in order of preference. This list is sent unencrypted to + * all clients that support ALPN. + * + *

For clients this is a list of supported protocols to match against the + * server's list. If there is no protocol supported by both client and + * server then the first protocol in the client's list will be selected. + * The order of the client's protocols is otherwise insignificant. + * + * @param protocols a non-empty list of protocol byte arrays. All arrays + * must be non-empty and of length less than 256. + * @hide + */ + public void setAlpnProtocols(byte[][] protocols) { + this.mAlpnProtocols = toLengthPrefixedList(protocols); } /** * Returns an array containing the concatenation of length-prefixed byte * strings. */ - static byte[] toNpnProtocolsList(byte[]... npnProtocols) { - if (npnProtocols.length == 0) { - throw new IllegalArgumentException("npnProtocols.length == 0"); + static byte[] toLengthPrefixedList(byte[]... items) { + if (items.length == 0) { + throw new IllegalArgumentException("items.length == 0"); } int totalLength = 0; - for (byte[] s : npnProtocols) { + for (byte[] s : items) { if (s.length == 0 || s.length > 255) { throw new IllegalArgumentException("s.length == 0 || s.length > 255: " + s.length); } @@ -288,7 +312,7 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { } byte[] result = new byte[totalLength]; int pos = 0; - for (byte[] s : npnProtocols) { + for (byte[] s : items) { result[pos++] = (byte) s.length; for (byte b : s) { result[pos++] = b; @@ -309,6 +333,20 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { return castToOpenSSLSocket(socket).getNpnSelectedProtocol(); } + /** + * Returns the + * Application + * Layer Protocol Negotiation (ALPN) protocol selected by client and server, or null + * if no protocol was negotiated. + * + * @param socket a socket created by this factory. + * @throws IllegalArgumentException if the socket was not created by this factory. + * @hide + */ + public byte[] getAlpnSelectedProtocol(Socket socket) { + return castToOpenSSLSocket(socket).getAlpnSelectedProtocol(); + } + /** * Sets the {@link KeyManager}s to be used for connections made by this factory. */ @@ -393,6 +431,7 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { public Socket createSocket(Socket k, String host, int port, boolean close) throws IOException { OpenSSLSocketImpl s = (OpenSSLSocketImpl) getDelegate().createSocket(k, host, port, close); s.setNpnProtocols(mNpnProtocols); + s.setAlpnProtocols(mAlpnProtocols); s.setHandshakeTimeout(mHandshakeTimeoutMillis); s.setChannelIdPrivateKey(mChannelIdPrivateKey); if (mSecure) { @@ -413,6 +452,7 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { public Socket createSocket() throws IOException { OpenSSLSocketImpl s = (OpenSSLSocketImpl) getDelegate().createSocket(); s.setNpnProtocols(mNpnProtocols); + s.setAlpnProtocols(mAlpnProtocols); s.setHandshakeTimeout(mHandshakeTimeoutMillis); s.setChannelIdPrivateKey(mChannelIdPrivateKey); return s; @@ -431,6 +471,7 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { OpenSSLSocketImpl s = (OpenSSLSocketImpl) getDelegate().createSocket( addr, port, localAddr, localPort); s.setNpnProtocols(mNpnProtocols); + s.setAlpnProtocols(mAlpnProtocols); s.setHandshakeTimeout(mHandshakeTimeoutMillis); s.setChannelIdPrivateKey(mChannelIdPrivateKey); return s; @@ -447,6 +488,7 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { public Socket createSocket(InetAddress addr, int port) throws IOException { OpenSSLSocketImpl s = (OpenSSLSocketImpl) getDelegate().createSocket(addr, port); s.setNpnProtocols(mNpnProtocols); + s.setAlpnProtocols(mAlpnProtocols); s.setHandshakeTimeout(mHandshakeTimeoutMillis); s.setChannelIdPrivateKey(mChannelIdPrivateKey); return s; @@ -464,6 +506,7 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { OpenSSLSocketImpl s = (OpenSSLSocketImpl) getDelegate().createSocket( host, port, localAddr, localPort); s.setNpnProtocols(mNpnProtocols); + s.setAlpnProtocols(mAlpnProtocols); s.setHandshakeTimeout(mHandshakeTimeoutMillis); s.setChannelIdPrivateKey(mChannelIdPrivateKey); if (mSecure) { @@ -482,6 +525,7 @@ public class SSLCertificateSocketFactory extends SSLSocketFactory { public Socket createSocket(String host, int port) throws IOException { OpenSSLSocketImpl s = (OpenSSLSocketImpl) getDelegate().createSocket(host, port); s.setNpnProtocols(mNpnProtocols); + s.setAlpnProtocols(mAlpnProtocols); s.setHandshakeTimeout(mHandshakeTimeoutMillis); s.setChannelIdPrivateKey(mChannelIdPrivateKey); if (mSecure) { diff --git a/core/tests/coretests/src/android/net/SSLTest.java b/core/tests/coretests/src/android/net/SSLTest.java index 27b699d805e4..45d28aef7974 100644 --- a/core/tests/coretests/src/android/net/SSLTest.java +++ b/core/tests/coretests/src/android/net/SSLTest.java @@ -49,35 +49,35 @@ public class SSLTest extends TestCase { // System.out.println(new String(b)); } - public void testStringsToNpnBytes() { + public void testStringsToLengthPrefixedBytes() { byte[] expected = { 6, 's', 'p', 'd', 'y', '/', '2', 8, 'h', 't', 't', 'p', '/', '1', '.', '1', }; - assertTrue(Arrays.equals(expected, SSLCertificateSocketFactory.toNpnProtocolsList( + assertTrue(Arrays.equals(expected, SSLCertificateSocketFactory.toLengthPrefixedList( new byte[] { 's', 'p', 'd', 'y', '/', '2' }, new byte[] { 'h', 't', 't', 'p', '/', '1', '.', '1' }))); } - public void testStringsToNpnBytesEmptyArray() { + public void testStringsToLengthPrefixedBytesEmptyArray() { try { - SSLCertificateSocketFactory.toNpnProtocolsList(); + SSLCertificateSocketFactory.toLengthPrefixedList(); fail(); } catch (IllegalArgumentException expected) { } } - public void testStringsToNpnBytesEmptyByteArray() { + public void testStringsToLengthPrefixedBytesEmptyByteArray() { try { - SSLCertificateSocketFactory.toNpnProtocolsList(new byte[0]); + SSLCertificateSocketFactory.toLengthPrefixedList(new byte[0]); fail(); } catch (IllegalArgumentException expected) { } } - public void testStringsToNpnBytesOversizedInput() { + public void testStringsToLengthPrefixedBytesOversizedInput() { try { - SSLCertificateSocketFactory.toNpnProtocolsList(new byte[256]); + SSLCertificateSocketFactory.toLengthPrefixedList(new byte[256]); fail(); } catch (IllegalArgumentException expected) { } -- cgit v1.2.3-59-g8ed1b