[adbwifi] Add A_STLS command.
This command will be sent by adbd to notify the client that the
connection will be over TLS.
When client connects, it will send the CNXN packet, as usual. If the
server connection has TLS enabled, it will send the A_STLS packet
(regardless of whether auth is required). At this point, the client's
only valid response is to send a A_STLS packet. Once both sides have
exchanged the A_STLS packet, both will start the TLS handshake.
If auth is required, then the client will receive a CertificateRequest
with a list of known public keys (SHA256 hash) that it can use in its
certificate. Otherwise, the list will be empty and the client can assume
that either any key will work, or none will work.
If the handshake was successful, the server will send the CNXN packet
and the usual adb protocol is resumed over TLS. If the handshake failed,
both sides will disconnect, as there's no point to retry because the
server's known keys have already been communicated.
Bug: 111434128
Test: WIP; will add to adb_test.py/adb_device.py.
Enable wireless debugging in the Settings, then 'adb connect
<ip>:<port>'. Connection should succeed if key is in keystore. Used
wireshark to check for packet encryption.
Change-Id: I3d60647491c6c6b92297e4f628707a6457fa9420
diff --git a/adb.cpp b/adb.cpp
index 460ddde..554a754 100644
--- a/adb.cpp
+++ b/adb.cpp
@@ -52,6 +52,7 @@
#include "adb_listeners.h"
#include "adb_unique_fd.h"
#include "adb_utils.h"
+#include "adb_wifi.h"
#include "sysdeps/chrono.h"
#include "transport.h"
@@ -140,6 +141,9 @@
case A_CLSE: tag = "CLSE"; break;
case A_WRTE: tag = "WRTE"; break;
case A_AUTH: tag = "AUTH"; break;
+ case A_STLS:
+ tag = "ATLS";
+ break;
default: tag = "????"; break;
}
@@ -209,6 +213,15 @@
android::base::Join(connection_properties, ';').c_str());
}
+void send_tls_request(atransport* t) {
+ D("Calling send_tls_request");
+ apacket* p = get_apacket();
+ p->msg.command = A_STLS;
+ p->msg.arg0 = A_STLS_VERSION;
+ p->msg.data_length = 0;
+ send_packet(p, t);
+}
+
void send_connect(atransport* t) {
D("Calling send_connect");
apacket* cp = get_apacket();
@@ -299,7 +312,12 @@
#if ADB_HOST
handle_online(t);
#else
- if (!auth_required) {
+ if (t->use_tls) {
+ // We still handshake in TLS mode. If auth_required is disabled,
+ // we'll just not verify the client's certificate. This should be the
+ // first packet the client receives to indicate the new protocol.
+ send_tls_request(t);
+ } else if (!auth_required) {
LOG(INFO) << "authentication not required";
handle_online(t);
send_connect(t);
@@ -324,8 +342,21 @@
case A_CNXN: // CONNECT(version, maxdata, "system-id-string")
handle_new_connection(t, p);
break;
+ case A_STLS: // TLS(version, "")
+ t->use_tls = true;
+#if ADB_HOST
+ send_tls_request(t);
+ adb_auth_tls_handshake(t);
+#else
+ adbd_auth_tls_handshake(t);
+#endif
+ break;
case A_AUTH:
+ // All AUTH commands are ignored in TLS mode
+ if (t->use_tls) {
+ break;
+ }
switch (p->msg.arg0) {
#if ADB_HOST
case ADB_AUTH_TOKEN:
diff --git a/adb.h b/adb.h
index 7f7dd0d..86d205c 100644
--- a/adb.h
+++ b/adb.h
@@ -44,6 +44,7 @@
#define A_CLSE 0x45534c43
#define A_WRTE 0x45545257
#define A_AUTH 0x48545541
+#define A_STLS 0x534C5453
// ADB protocol version.
// Version revision:
@@ -53,6 +54,10 @@
#define A_VERSION_SKIP_CHECKSUM 0x01000001
#define A_VERSION 0x01000001
+// Stream-based TLS protocol version
+#define A_STLS_VERSION_MIN 0x01000000
+#define A_STLS_VERSION 0x01000000
+
// Used for help/version information.
#define ADB_VERSION_MAJOR 1
#define ADB_VERSION_MINOR 0
@@ -229,6 +234,7 @@
void handle_offline(atransport* t);
void send_connect(atransport* t);
+void send_tls_request(atransport* t);
void parse_banner(const std::string&, atransport* t);
diff --git a/adb_auth.h b/adb_auth.h
index 09c3a2d..7e858dc 100644
--- a/adb_auth.h
+++ b/adb_auth.h
@@ -43,6 +43,9 @@
void send_auth_response(const char* token, size_t token_size, atransport* t);
+int adb_tls_set_certificate(SSL* ssl);
+void adb_auth_tls_handshake(atransport* t);
+
#else // !ADB_HOST
extern bool auth_required;
@@ -58,6 +61,10 @@
void send_auth_request(atransport *t);
+void adbd_auth_tls_handshake(atransport* t);
+int adbd_tls_verify_cert(X509_STORE_CTX* ctx, std::string* auth_key);
+bssl::UniquePtr<STACK_OF(X509_NAME)> adbd_tls_client_ca_list();
+
#endif // ADB_HOST
#endif // __ADB_AUTH_H
diff --git a/client/auth.cpp b/client/auth.cpp
index a2eff7f..8738ce7 100644
--- a/client/auth.cpp
+++ b/client/auth.cpp
@@ -30,6 +30,9 @@
#include <string>
#include <adb/crypto/rsa_2048_key.h>
+#include <adb/crypto/x509_generator.h>
+#include <adb/tls/adb_ca_list.h>
+#include <adb/tls/tls_connection.h>
#include <android-base/errors.h>
#include <android-base/file.h>
#include <android-base/stringprintf.h>
@@ -55,6 +58,7 @@
static std::map<int, std::string>& g_monitored_paths = *new std::map<int, std::string>;
using namespace adb::crypto;
+using namespace adb::tls;
static bool generate_key(const std::string& file) {
LOG(INFO) << "generate_key(" << file << ")...";
@@ -144,6 +148,7 @@
if (g_keys.find(fingerprint) != g_keys.end()) {
LOG(INFO) << "ignoring already-loaded key: " << file;
} else {
+ LOG(INFO) << "Loaded fingerprint=[" << SHA256BitsToHexString(fingerprint) << "]";
g_keys[fingerprint] = std::move(key);
}
return true;
@@ -475,3 +480,72 @@
p->msg.data_length = p->payload.size();
send_packet(p, t);
}
+
+void adb_auth_tls_handshake(atransport* t) {
+ std::thread([t]() {
+ std::shared_ptr<RSA> key = t->Key();
+ if (key == nullptr) {
+ // Can happen if !auth_required
+ LOG(INFO) << "t->auth_key not set before handshake";
+ key = t->NextKey();
+ CHECK(key);
+ }
+
+ LOG(INFO) << "Attempting to TLS handshake";
+ bool success = t->connection()->DoTlsHandshake(key.get());
+ if (success) {
+ LOG(INFO) << "Handshake succeeded. Waiting for CNXN packet...";
+ } else {
+ LOG(INFO) << "Handshake failed. Kicking transport";
+ t->Kick();
+ }
+ }).detach();
+}
+
+int adb_tls_set_certificate(SSL* ssl) {
+ LOG(INFO) << __func__;
+
+ const STACK_OF(X509_NAME)* ca_list = SSL_get_client_CA_list(ssl);
+ if (ca_list == nullptr) {
+ // Either the device doesn't know any keys, or !auth_required.
+ // So let's just try with the default certificate and see what happens.
+ LOG(INFO) << "No client CA list. Trying with default certificate.";
+ return 1;
+ }
+
+ const size_t num_cas = sk_X509_NAME_num(ca_list);
+ for (size_t i = 0; i < num_cas; ++i) {
+ auto* x509_name = sk_X509_NAME_value(ca_list, i);
+ auto adbFingerprint = ParseEncodedKeyFromCAIssuer(x509_name);
+ if (!adbFingerprint.has_value()) {
+ // This could be a real CA issuer. Unfortunately, we don't support
+ // it ATM.
+ continue;
+ }
+
+ LOG(INFO) << "Checking for fingerprint match [" << *adbFingerprint << "]";
+ auto encoded_key = SHA256HexStringToBits(*adbFingerprint);
+ if (!encoded_key.has_value()) {
+ continue;
+ }
+ // Check against our list of encoded keys for a match
+ std::lock_guard<std::mutex> lock(g_keys_mutex);
+ auto rsa_priv_key = g_keys.find(*encoded_key);
+ if (rsa_priv_key != g_keys.end()) {
+ LOG(INFO) << "Got SHA256 match on a key";
+ bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
+ CHECK(EVP_PKEY_set1_RSA(evp_pkey.get(), rsa_priv_key->second.get()));
+ auto x509 = GenerateX509Certificate(evp_pkey.get());
+ auto x509_str = X509ToPEMString(x509.get());
+ auto evp_str = Key::ToPEMString(evp_pkey.get());
+ TlsConnection::SetCertAndKey(ssl, x509_str, evp_str);
+ return 1;
+ } else {
+ LOG(INFO) << "No match for [" << *adbFingerprint << "]";
+ }
+ }
+
+ // Let's just try with the default certificate anyways, because daemon might
+ // not require auth, even though it has a list of keys.
+ return 1;
+}
diff --git a/daemon/adb_wifi.cpp b/daemon/adb_wifi.cpp
index 2d47719..bce303b 100644
--- a/daemon/adb_wifi.cpp
+++ b/daemon/adb_wifi.cpp
@@ -142,9 +142,9 @@
close_on_exec(new_fd.get());
disable_tcp_nagle(new_fd.get());
std::string serial = android::base::StringPrintf("host-%d", new_fd.get());
- // TODO: register a tls transport
- // register_socket_transport(std::move(new_fd), std::move(serial), port_, 1,
- // [](atransport*) { return ReconnectResult::Abort; });
+ register_socket_transport(
+ std::move(new_fd), std::move(serial), port_, 1,
+ [](atransport*) { return ReconnectResult::Abort; }, true);
}
}
@@ -224,4 +224,5 @@
t->auth_id = adbd_auth_tls_device_connected(auth_ctx, kAdbTransportTypeWifi, t->auth_key.data(),
t->auth_key.size());
}
+
#endif /* !HOST */
diff --git a/daemon/auth.cpp b/daemon/auth.cpp
index 22ea9ff..2edf582 100644
--- a/daemon/auth.cpp
+++ b/daemon/auth.cpp
@@ -23,10 +23,14 @@
#include <string.h>
#include <algorithm>
+#include <chrono>
#include <iomanip>
#include <map>
#include <memory>
+#include <thread>
+#include <adb/crypto/rsa_2048_key.h>
+#include <adb/tls/adb_ca_list.h>
#include <adbd_auth.h>
#include <android-base/file.h>
#include <android-base/no_destructor.h>
@@ -45,8 +49,14 @@
#include "transport.h"
#include "types.h"
+using namespace adb::crypto;
+using namespace adb::tls;
+using namespace std::chrono_literals;
+
static AdbdAuthContext* auth_ctx;
+static RSA* rsa_pkey = nullptr;
+
static void adb_disconnected(void* unused, atransport* t);
static struct adisconnect adb_disconnect = {adb_disconnected, nullptr};
@@ -93,6 +103,55 @@
&f);
}
+bssl::UniquePtr<STACK_OF(X509_NAME)> adbd_tls_client_ca_list() {
+ if (!auth_required) {
+ return nullptr;
+ }
+
+ bssl::UniquePtr<STACK_OF(X509_NAME)> ca_list(sk_X509_NAME_new_null());
+
+ IteratePublicKeys([&](std::string_view public_key) {
+ // TODO: do we really have to support both ' ' and '\t'?
+ std::vector<std::string> split = android::base::Split(std::string(public_key), " \t");
+ uint8_t keybuf[ANDROID_PUBKEY_ENCODED_SIZE + 1];
+ const std::string& pubkey = split[0];
+ if (b64_pton(pubkey.c_str(), keybuf, sizeof(keybuf)) != ANDROID_PUBKEY_ENCODED_SIZE) {
+ LOG(ERROR) << "Invalid base64 key " << pubkey;
+ return true;
+ }
+
+ RSA* key = nullptr;
+ if (!android_pubkey_decode(keybuf, ANDROID_PUBKEY_ENCODED_SIZE, &key)) {
+ LOG(ERROR) << "Failed to parse key " << pubkey;
+ return true;
+ }
+ bssl::UniquePtr<RSA> rsa_key(key);
+
+ unsigned char* dkey = nullptr;
+ int len = i2d_RSA_PUBKEY(rsa_key.get(), &dkey);
+ if (len <= 0 || dkey == nullptr) {
+ LOG(ERROR) << "Failed to encode RSA public key";
+ return true;
+ }
+
+ uint8_t digest[SHA256_DIGEST_LENGTH];
+ // Put the encoded key in the commonName attribute of the issuer name.
+ // Note that the commonName has a max length of 64 bytes, which is less
+ // than the SHA256_DIGEST_LENGTH.
+ SHA256(dkey, len, digest);
+ OPENSSL_free(dkey);
+
+ auto digest_str = SHA256BitsToHexString(
+ std::string_view(reinterpret_cast<const char*>(&digest[0]), sizeof(digest)));
+ LOG(INFO) << "fingerprint=[" << digest_str << "]";
+ auto issuer = CreateCAIssuerFromEncodedKey(digest_str);
+ CHECK(bssl::PushToStack(ca_list.get(), std::move(issuer)));
+ return true;
+ });
+
+ return ca_list;
+}
+
bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig,
std::string* auth_key) {
bool authorized = false;
@@ -217,5 +276,89 @@
}
void adbd_notify_framework_connected_key(atransport* t) {
- adbd_auth_notify_auth(auth_ctx, t->auth_key.data(), t->auth_key.size());
+ t->auth_id = adbd_auth_notify_auth(auth_ctx, t->auth_key.data(), t->auth_key.size());
+}
+
+int adbd_tls_verify_cert(X509_STORE_CTX* ctx, std::string* auth_key) {
+ if (!auth_required) {
+ // Any key will do.
+ LOG(INFO) << __func__ << ": auth not required";
+ return 1;
+ }
+
+ bool authorized = false;
+ X509* cert = X509_STORE_CTX_get0_cert(ctx);
+ if (cert == nullptr) {
+ LOG(INFO) << "got null x509 certificate";
+ return 0;
+ }
+ bssl::UniquePtr<EVP_PKEY> evp_pkey(X509_get_pubkey(cert));
+ if (evp_pkey == nullptr) {
+ LOG(INFO) << "got null evp_pkey from x509 certificate";
+ return 0;
+ }
+
+ IteratePublicKeys([&](std::string_view public_key) {
+ // TODO: do we really have to support both ' ' and '\t'?
+ std::vector<std::string> split = android::base::Split(std::string(public_key), " \t");
+ uint8_t keybuf[ANDROID_PUBKEY_ENCODED_SIZE + 1];
+ const std::string& pubkey = split[0];
+ if (b64_pton(pubkey.c_str(), keybuf, sizeof(keybuf)) != ANDROID_PUBKEY_ENCODED_SIZE) {
+ LOG(ERROR) << "Invalid base64 key " << pubkey;
+ return true;
+ }
+
+ RSA* key = nullptr;
+ if (!android_pubkey_decode(keybuf, ANDROID_PUBKEY_ENCODED_SIZE, &key)) {
+ LOG(ERROR) << "Failed to parse key " << pubkey;
+ return true;
+ }
+
+ bool verified = false;
+ bssl::UniquePtr<EVP_PKEY> known_evp(EVP_PKEY_new());
+ EVP_PKEY_set1_RSA(known_evp.get(), key);
+ if (EVP_PKEY_cmp(known_evp.get(), evp_pkey.get())) {
+ LOG(INFO) << "Matched auth_key=" << public_key;
+ verified = true;
+ } else {
+ LOG(INFO) << "auth_key doesn't match [" << public_key << "]";
+ }
+ RSA_free(key);
+ if (verified) {
+ *auth_key = public_key;
+ authorized = true;
+ return false;
+ }
+
+ return true;
+ });
+
+ return authorized ? 1 : 0;
+}
+
+void adbd_auth_tls_handshake(atransport* t) {
+ if (rsa_pkey == nullptr) {
+ // Generate a random RSA key to feed into the X509 certificate
+ auto rsa_2048 = CreateRSA2048Key();
+ CHECK(rsa_2048.has_value());
+ rsa_pkey = EVP_PKEY_get1_RSA(rsa_2048->GetEvpPkey());
+ CHECK(rsa_pkey);
+ }
+
+ std::thread([t]() {
+ std::string auth_key;
+ if (t->connection()->DoTlsHandshake(rsa_pkey, &auth_key)) {
+ LOG(INFO) << "auth_key=" << auth_key;
+ if (t->IsTcpDevice()) {
+ t->auth_key = auth_key;
+ adbd_wifi_secure_connect(t);
+ } else {
+ adbd_auth_verified(t);
+ adbd_notify_framework_connected_key(t);
+ }
+ } else {
+ // Only allow one attempt at the handshake.
+ t->Kick();
+ }
+ }).detach();
}
diff --git a/daemon/transport_qemu.cpp b/daemon/transport_qemu.cpp
index 901efee..e458cea 100644
--- a/daemon/transport_qemu.cpp
+++ b/daemon/transport_qemu.cpp
@@ -105,8 +105,9 @@
* exchange. */
std::string serial = android::base::StringPrintf("host-%d", fd.get());
WriteFdExactly(fd.get(), _start_req, strlen(_start_req));
- register_socket_transport(std::move(fd), std::move(serial), port, 1,
- [](atransport*) { return ReconnectResult::Abort; });
+ register_socket_transport(
+ std::move(fd), std::move(serial), port, 1,
+ [](atransport*) { return ReconnectResult::Abort; }, false);
}
/* Prepare for accepting of the next ADB host connection. */
diff --git a/daemon/usb.cpp b/daemon/usb.cpp
index a9ad805..c7f8895 100644
--- a/daemon/usb.cpp
+++ b/daemon/usb.cpp
@@ -260,6 +260,12 @@
CHECK_EQ(static_cast<size_t>(rc), sizeof(notify));
}
+ virtual bool DoTlsHandshake(RSA* key, std::string* auth_key) override final {
+ // TODO: support TLS for usb connections.
+ LOG(FATAL) << "Not supported yet.";
+ return false;
+ }
+
private:
void StartMonitor() {
// This is a bit of a mess.
diff --git a/protocol.txt b/protocol.txt
index f4523c4..75700a4 100644
--- a/protocol.txt
+++ b/protocol.txt
@@ -79,6 +79,14 @@
kind of unique ID (or empty), and banner is a human-readable version
or identifier string. The banner is used to transmit useful properties.
+--- STLS(type, version, "") --------------------------------------------
+
+Command constant: A_STLS
+
+The TLS message informs the recipient that the connection will be encrypted
+and will need to perform a TLS handshake. version is the current version of
+the protocol.
+
--- AUTH(type, 0, "data") ----------------------------------------------
@@ -207,6 +215,7 @@
#define A_OKAY 0x59414b4f
#define A_CLSE 0x45534c43
#define A_WRTE 0x45545257
+#define A_STLS 0x534C5453
diff --git a/transport.cpp b/transport.cpp
index 9dd6ec6..8b3461a 100644
--- a/transport.cpp
+++ b/transport.cpp
@@ -36,6 +36,9 @@
#include <set>
#include <thread>
+#include <adb/crypto/rsa_2048_key.h>
+#include <adb/crypto/x509_generator.h>
+#include <adb/tls/tls_connection.h>
#include <android-base/logging.h>
#include <android-base/parsenetaddress.h>
#include <android-base/stringprintf.h>
@@ -52,7 +55,10 @@
#include "fdevent/fdevent.h"
#include "sysdeps/chrono.h"
+using namespace adb::crypto;
+using namespace adb::tls;
using android::base::ScopedLockAssertion;
+using TlsError = TlsConnection::TlsError;
static void remove_transport(atransport* transport);
static void transport_destroy(atransport* transport);
@@ -279,18 +285,7 @@
<< "): started multiple times";
}
- read_thread_ = std::thread([this]() {
- LOG(INFO) << this->transport_name_ << ": read thread spawning";
- while (true) {
- auto packet = std::make_unique<apacket>();
- if (!underlying_->Read(packet.get())) {
- PLOG(INFO) << this->transport_name_ << ": read failed";
- break;
- }
- read_callback_(this, std::move(packet));
- }
- std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
- });
+ StartReadThread();
write_thread_ = std::thread([this]() {
LOG(INFO) << this->transport_name_ << ": write thread spawning";
@@ -319,6 +314,46 @@
started_ = true;
}
+void BlockingConnectionAdapter::StartReadThread() {
+ read_thread_ = std::thread([this]() {
+ LOG(INFO) << this->transport_name_ << ": read thread spawning";
+ while (true) {
+ auto packet = std::make_unique<apacket>();
+ if (!underlying_->Read(packet.get())) {
+ PLOG(INFO) << this->transport_name_ << ": read failed";
+ break;
+ }
+
+ bool got_stls_cmd = false;
+ if (packet->msg.command == A_STLS) {
+ got_stls_cmd = true;
+ }
+
+ read_callback_(this, std::move(packet));
+
+ // If we received the STLS packet, we are about to perform the TLS
+ // handshake. So this read thread must stop and resume after the
+ // handshake completes otherwise this will interfere in the process.
+ if (got_stls_cmd) {
+ LOG(INFO) << this->transport_name_
+ << ": Received STLS packet. Stopping read thread.";
+ return;
+ }
+ }
+ std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
+ });
+}
+
+bool BlockingConnectionAdapter::DoTlsHandshake(RSA* key, std::string* auth_key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (read_thread_.joinable()) {
+ read_thread_.join();
+ }
+ bool success = this->underlying_->DoTlsHandshake(key, auth_key);
+ StartReadThread();
+ return success;
+}
+
void BlockingConnectionAdapter::Reset() {
{
std::lock_guard<std::mutex> lock(mutex_);
@@ -388,8 +423,36 @@
return true;
}
+FdConnection::FdConnection(unique_fd fd) : fd_(std::move(fd)) {}
+
+FdConnection::~FdConnection() {}
+
+bool FdConnection::DispatchRead(void* buf, size_t len) {
+ if (tls_ != nullptr) {
+ // The TlsConnection doesn't allow 0 byte reads
+ if (len == 0) {
+ return true;
+ }
+ return tls_->ReadFully(buf, len);
+ }
+
+ return ReadFdExactly(fd_.get(), buf, len);
+}
+
+bool FdConnection::DispatchWrite(void* buf, size_t len) {
+ if (tls_ != nullptr) {
+ // The TlsConnection doesn't allow 0 byte writes
+ if (len == 0) {
+ return true;
+ }
+ return tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(buf), len));
+ }
+
+ return WriteFdExactly(fd_.get(), buf, len);
+}
+
bool FdConnection::Read(apacket* packet) {
- if (!ReadFdExactly(fd_.get(), &packet->msg, sizeof(amessage))) {
+ if (!DispatchRead(&packet->msg, sizeof(amessage))) {
D("remote local: read terminated (message)");
return false;
}
@@ -401,7 +464,7 @@
packet->payload.resize(packet->msg.data_length);
- if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) {
+ if (!DispatchRead(&packet->payload[0], packet->payload.size())) {
D("remote local: terminated (data)");
return false;
}
@@ -410,13 +473,13 @@
}
bool FdConnection::Write(apacket* packet) {
- if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) {
+ if (!DispatchWrite(&packet->msg, sizeof(packet->msg))) {
D("remote local: write terminated");
return false;
}
if (packet->msg.data_length) {
- if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) {
+ if (!DispatchWrite(&packet->payload[0], packet->msg.data_length)) {
D("remote local: write terminated");
return false;
}
@@ -425,6 +488,51 @@
return true;
}
+bool FdConnection::DoTlsHandshake(RSA* key, std::string* auth_key) {
+ bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
+ if (!EVP_PKEY_set1_RSA(evp_pkey.get(), key)) {
+ LOG(ERROR) << "EVP_PKEY_set1_RSA failed";
+ return false;
+ }
+ auto x509 = GenerateX509Certificate(evp_pkey.get());
+ auto x509_str = X509ToPEMString(x509.get());
+ auto evp_str = Key::ToPEMString(evp_pkey.get());
+#if ADB_HOST
+ tls_ = TlsConnection::Create(TlsConnection::Role::Client,
+#else
+ tls_ = TlsConnection::Create(TlsConnection::Role::Server,
+#endif
+ x509_str, evp_str, fd_);
+ CHECK(tls_);
+#if ADB_HOST
+ // TLS 1.3 gives the client no message if the server rejected the
+ // certificate. This will enable a check in the tls connection to check
+ // whether the client certificate got rejected. Note that this assumes
+ // that, on handshake success, the server speaks first.
+ tls_->EnableClientPostHandshakeCheck(true);
+ // Add callback to set the certificate when server issues the
+ // CertificateRequest.
+ tls_->SetCertificateCallback(adb_tls_set_certificate);
+ // Allow any server certificate
+ tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
+#else
+ // Add callback to check certificate against a list of known public keys
+ tls_->SetCertVerifyCallback(
+ [auth_key](X509_STORE_CTX* ctx) { return adbd_tls_verify_cert(ctx, auth_key); });
+ // Add the list of allowed client CA issuers
+ auto ca_list = adbd_tls_client_ca_list();
+ tls_->SetClientCAList(ca_list.get());
+#endif
+
+ auto err = tls_->DoHandshake();
+ if (err == TlsError::Success) {
+ return true;
+ }
+
+ tls_.reset();
+ return false;
+}
+
void FdConnection::Close() {
adb_shutdown(fd_.get());
fd_.reset();
@@ -750,6 +858,26 @@
}
}
+void kick_all_tcp_tls_transports() {
+ std::lock_guard<std::recursive_mutex> lock(transport_lock);
+ for (auto t : transport_list) {
+ if (t->IsTcpDevice() && t->use_tls) {
+ t->Kick();
+ }
+ }
+}
+
+#if !ADB_HOST
+void kick_all_transports_by_auth_key(std::string_view auth_key) {
+ std::lock_guard<std::recursive_mutex> lock(transport_lock);
+ for (auto t : transport_list) {
+ if (auth_key == t->auth_key) {
+ t->Kick();
+ }
+ }
+}
+#endif
+
/* the fdevent select pump is single threaded */
void register_transport(atransport* transport) {
tmsg m;
@@ -1026,6 +1154,10 @@
return protocol_version;
}
+int atransport::get_tls_version() const {
+ return tls_version;
+}
+
size_t atransport::get_max_payload() const {
return max_payload;
}
@@ -1221,8 +1353,9 @@
#endif // ADB_HOST
bool register_socket_transport(unique_fd s, std::string serial, int port, int local,
- atransport::ReconnectCallback reconnect, int* error) {
+ atransport::ReconnectCallback reconnect, bool use_tls, int* error) {
atransport* t = new atransport(std::move(reconnect), kCsOffline);
+ t->use_tls = use_tls;
D("transport: %s init'ing for socket %d, on port %d", serial.c_str(), s.get(), port);
if (init_socket_transport(t, std::move(s), port, local) < 0) {
@@ -1360,6 +1493,15 @@
}
#if ADB_HOST
+std::shared_ptr<RSA> atransport::Key() {
+ if (keys_.empty()) {
+ return nullptr;
+ }
+
+ std::shared_ptr<RSA> result = keys_[0];
+ return result;
+}
+
std::shared_ptr<RSA> atransport::NextKey() {
if (keys_.empty()) {
LOG(INFO) << "fetching keys for transport " << this->serial_name();
@@ -1367,10 +1509,11 @@
// We should have gotten at least one key: the one that's automatically generated.
CHECK(!keys_.empty());
+ } else {
+ keys_.pop_front();
}
std::shared_ptr<RSA> result = keys_[0];
- keys_.pop_front();
return result;
}
diff --git a/transport.h b/transport.h
index 5a750ee..8a0f62a 100644
--- a/transport.h
+++ b/transport.h
@@ -43,6 +43,14 @@
typedef std::unordered_set<std::string> FeatureSet;
+namespace adb {
+namespace tls {
+
+class TlsConnection;
+
+} // namespace tls
+} // namespace adb
+
const FeatureSet& supported_features();
// Encodes and decodes FeatureSet objects into human-readable strings.
@@ -104,6 +112,8 @@
virtual void Start() = 0;
virtual void Stop() = 0;
+ virtual bool DoTlsHandshake(RSA* key, std::string* auth_key = nullptr) = 0;
+
// Stop, and reset the device if it's a USB connection.
virtual void Reset();
@@ -128,6 +138,8 @@
virtual bool Read(apacket* packet) = 0;
virtual bool Write(apacket* packet) = 0;
+ virtual bool DoTlsHandshake(RSA* key, std::string* auth_key = nullptr) = 0;
+
// Terminate a connection.
// This method must be thread-safe, and must cause concurrent Reads/Writes to terminate.
// Formerly known as 'Kick' in atransport.
@@ -146,9 +158,12 @@
virtual void Start() override final;
virtual void Stop() override final;
+ virtual bool DoTlsHandshake(RSA* key, std::string* auth_key) override final;
virtual void Reset() override final;
+ private:
+ void StartReadThread() REQUIRES(mutex_);
bool started_ GUARDED_BY(mutex_) = false;
bool stopped_ GUARDED_BY(mutex_) = false;
@@ -164,16 +179,22 @@
};
struct FdConnection : public BlockingConnection {
- explicit FdConnection(unique_fd fd) : fd_(std::move(fd)) {}
+ explicit FdConnection(unique_fd fd);
+ ~FdConnection();
bool Read(apacket* packet) override final;
bool Write(apacket* packet) override final;
+ bool DoTlsHandshake(RSA* key, std::string* auth_key) override final;
void Close() override;
virtual void Reset() override final { Close(); }
private:
+ bool DispatchRead(void* buf, size_t len);
+ bool DispatchWrite(void* buf, size_t len);
+
unique_fd fd_;
+ std::unique_ptr<adb::tls::TlsConnection> tls_;
};
struct UsbConnection : public BlockingConnection {
@@ -182,6 +203,7 @@
bool Read(apacket* packet) override final;
bool Write(apacket* packet) override final;
+ bool DoTlsHandshake(RSA* key, std::string* auth_key) override final;
void Close() override final;
virtual void Reset() override final;
@@ -279,6 +301,12 @@
std::string device;
std::string devpath;
+ // If this is set, the transport will initiate the connection with a
+ // START_TLS command, instead of AUTH.
+ bool use_tls = false;
+ int tls_version = A_STLS_VERSION;
+ int get_tls_version() const;
+
#if !ADB_HOST
// Used to provide the key to the framework.
std::string auth_key;
@@ -288,6 +316,8 @@
bool IsTcpDevice() const { return type == kTransportLocal; }
#if ADB_HOST
+ // The current key being authorized.
+ std::shared_ptr<RSA> Key();
std::shared_ptr<RSA> NextKey();
void ResetKeys();
#endif
@@ -400,6 +430,10 @@
atransport* find_transport(const char* serial);
void kick_all_tcp_devices();
void kick_all_transports();
+void kick_all_tcp_tls_transports();
+#if !ADB_HOST
+void kick_all_transports_by_auth_key(std::string_view auth_key);
+#endif
void register_transport(atransport* transport);
void register_usb_transport(usb_handle* h, const char* serial,
@@ -410,7 +444,8 @@
/* cause new transports to be init'd and added to the list */
bool register_socket_transport(unique_fd s, std::string serial, int port, int local,
- atransport::ReconnectCallback reconnect, int* error = nullptr);
+ atransport::ReconnectCallback reconnect, bool use_tls,
+ int* error = nullptr);
// This should only be used for transports with connection_state == kCsNoPerm.
void unregister_usb_transport(usb_handle* usb);
diff --git a/transport_fd.cpp b/transport_fd.cpp
index 8d2ad66..b9b4f42 100644
--- a/transport_fd.cpp
+++ b/transport_fd.cpp
@@ -155,6 +155,11 @@
thread_.join();
}
+ bool DoTlsHandshake(RSA* key, std::string* auth_key) override final {
+ LOG(FATAL) << "Not supported yet";
+ return false;
+ }
+
void WakeThread() {
uint64_t buf = 0;
if (TEMP_FAILURE_RETRY(adb_write(wake_fd_write_.get(), &buf, sizeof(buf))) != sizeof(buf)) {
diff --git a/transport_local.cpp b/transport_local.cpp
index c726186..5ec8e16 100644
--- a/transport_local.cpp
+++ b/transport_local.cpp
@@ -126,7 +126,8 @@
};
int error;
- if (!register_socket_transport(std::move(fd), serial, port, 0, std::move(reconnect), &error)) {
+ if (!register_socket_transport(std::move(fd), serial, port, 0, std::move(reconnect), false,
+ &error)) {
if (error == EALREADY) {
*response = android::base::StringPrintf("already connected to %s", serial.c_str());
} else if (error == EPERM) {
@@ -163,8 +164,9 @@
close_on_exec(fd.get());
disable_tcp_nagle(fd.get());
std::string serial = getEmulatorSerialString(console_port);
- if (register_socket_transport(std::move(fd), std::move(serial), adb_port, 1,
- [](atransport*) { return ReconnectResult::Abort; })) {
+ if (register_socket_transport(
+ std::move(fd), std::move(serial), adb_port, 1,
+ [](atransport*) { return ReconnectResult::Abort; }, false)) {
return 0;
}
}
@@ -271,8 +273,9 @@
std::string serial = android::base::StringPrintf("host-%d", fd.get());
// We don't care about port value in "register_socket_transport" as it is used
// only from ADB_HOST. "server_socket_thread" is never called from ADB_HOST.
- register_socket_transport(std::move(fd), std::move(serial), 0, 1,
- [](atransport*) { return ReconnectResult::Abort; });
+ register_socket_transport(
+ std::move(fd), std::move(serial), 0, 1,
+ [](atransport*) { return ReconnectResult::Abort; }, false);
}
}
D("transport: server_socket_thread() exiting");
@@ -365,7 +368,7 @@
if (local) {
auto emulator_connection = std::make_unique<EmulatorConnection>(std::move(fd), adb_port);
t->SetConnection(
- std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection)));
+ std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection)));
std::lock_guard<std::mutex> lock(local_transports_lock);
atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port);
if (existing_transport != nullptr) {
diff --git a/transport_usb.cpp b/transport_usb.cpp
index 3e87522..fb81b37 100644
--- a/transport_usb.cpp
+++ b/transport_usb.cpp
@@ -171,6 +171,12 @@
return true;
}
+bool UsbConnection::DoTlsHandshake(RSA* key, std::string* auth_key) {
+ // TODO: support TLS for usb connections
+ LOG(FATAL) << "Not supported yet.";
+ return false;
+}
+
void UsbConnection::Reset() {
usb_reset(handle_);
usb_kick(handle_);