diff options
| author | 2021-09-21 21:03:39 +0000 | |
|---|---|---|
| committer | 2021-09-21 21:03:39 +0000 | |
| commit | 89932bc5b7265f040d2c9131637a16da008201db (patch) | |
| tree | 57d0a2a45bb491b4e6626b0b5769220fdff2a98e | |
| parent | 5d4da609efeb8a8b76296d4b4afaaf76479e4cde (diff) | |
| parent | b1ce80cabff8bec04c1d6c8556a068cb016a0999 (diff) | |
Merge changes from topic "binder_presigned_keys"
* changes:
binder: Add tests for using pre-signed certificates.
binder: Add utils for (de)serializing key pairs.
| -rw-r--r-- | libs/binder/RpcTlsUtils.cpp | 89 | ||||
| -rw-r--r-- | libs/binder/include/binder/RpcKeyFormat.h | 41 | ||||
| -rw-r--r-- | libs/binder/include_tls/binder/RpcTlsUtils.h | 10 | ||||
| -rw-r--r-- | libs/binder/tests/Android.bp | 31 | ||||
| -rw-r--r-- | libs/binder/tests/RpcAuthTesting.cpp | 10 | ||||
| -rw-r--r-- | libs/binder/tests/RpcAuthTesting.h | 11 | ||||
| -rw-r--r-- | libs/binder/tests/RpcTlsUtilsTest.cpp | 115 | ||||
| -rw-r--r-- | libs/binder/tests/binderRpcTest.cpp | 191 |
8 files changed, 406 insertions, 92 deletions
diff --git a/libs/binder/RpcTlsUtils.cpp b/libs/binder/RpcTlsUtils.cpp index 483cc7c58f..f3ca02a3bd 100644 --- a/libs/binder/RpcTlsUtils.cpp +++ b/libs/binder/RpcTlsUtils.cpp @@ -25,54 +25,87 @@ namespace android { namespace { -bssl::UniquePtr<X509> fromPem(const std::vector<uint8_t>& cert) { - if (cert.size() > std::numeric_limits<int>::max()) return nullptr; - bssl::UniquePtr<BIO> certBio(BIO_new_mem_buf(cert.data(), static_cast<int>(cert.size()))); - return bssl::UniquePtr<X509>(PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr)); +static_assert(sizeof(unsigned char) == sizeof(uint8_t)); + +template <typename PemReadBioFn, + typename T = std::remove_pointer_t<std::invoke_result_t< + PemReadBioFn, BIO*, std::nullptr_t, std::nullptr_t, std::nullptr_t>>> +bssl::UniquePtr<T> fromPem(const std::vector<uint8_t>& data, PemReadBioFn fn) { + if (data.size() > std::numeric_limits<int>::max()) return nullptr; + bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(data.data(), static_cast<int>(data.size()))); + return bssl::UniquePtr<T>(fn(bio.get(), nullptr, nullptr, nullptr)); } -bssl::UniquePtr<X509> fromDer(const std::vector<uint8_t>& cert) { - if (cert.size() > std::numeric_limits<long>::max()) return nullptr; - const unsigned char* data = cert.data(); - auto expectedEnd = data + cert.size(); - bssl::UniquePtr<X509> ret(d2i_X509(nullptr, &data, static_cast<long>(cert.size()))); - if (data != expectedEnd) { - ALOGE("%s: %td bytes remaining!", __PRETTY_FUNCTION__, expectedEnd - data); +template <typename D2iFn, + typename T = std::remove_pointer_t< + std::invoke_result_t<D2iFn, std::nullptr_t, const unsigned char**, long>>> +bssl::UniquePtr<T> fromDer(const std::vector<uint8_t>& data, D2iFn fn) { + if (data.size() > std::numeric_limits<long>::max()) return nullptr; + const unsigned char* dataPtr = data.data(); + auto expectedEnd = dataPtr + data.size(); + bssl::UniquePtr<T> ret(fn(nullptr, &dataPtr, static_cast<long>(data.size()))); + if (dataPtr != expectedEnd) { + ALOGE("%s: %td bytes remaining!", __PRETTY_FUNCTION__, expectedEnd - dataPtr); return nullptr; } return ret; } +template <typename T, typename WriteBioFn = int (*)(BIO*, T*)> +std::vector<uint8_t> serialize(T* object, WriteBioFn writeBio) { + bssl::UniquePtr<BIO> bio(BIO_new(BIO_s_mem())); + TEST_AND_RETURN({}, writeBio(bio.get(), object)); + const uint8_t* data; + size_t len; + TEST_AND_RETURN({}, BIO_mem_contents(bio.get(), &data, &len)); + return std::vector<uint8_t>(data, data + len); +} + } // namespace -bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& cert, +bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& data, RpcCertificateFormat format) { switch (format) { case RpcCertificateFormat::PEM: - return fromPem(cert); + return fromPem(data, PEM_read_bio_X509); case RpcCertificateFormat::DER: - return fromDer(cert); + return fromDer(data, d2i_X509); } LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format)); } std::vector<uint8_t> serializeCertificate(X509* x509, RpcCertificateFormat format) { - bssl::UniquePtr<BIO> certBio(BIO_new(BIO_s_mem())); switch (format) { - case RpcCertificateFormat::PEM: { - TEST_AND_RETURN({}, PEM_write_bio_X509(certBio.get(), x509)); - } break; - case RpcCertificateFormat::DER: { - TEST_AND_RETURN({}, i2d_X509_bio(certBio.get(), x509)); - } break; - default: { - LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format)); - } + case RpcCertificateFormat::PEM: + return serialize(x509, PEM_write_bio_X509); + case RpcCertificateFormat::DER: + return serialize(x509, i2d_X509_bio); } - const uint8_t* data; - size_t len; - TEST_AND_RETURN({}, BIO_mem_contents(certBio.get(), &data, &len)); - return std::vector<uint8_t>(data, data + len); + LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format)); +} + +bssl::UniquePtr<EVP_PKEY> deserializeUnencryptedPrivatekey(const std::vector<uint8_t>& data, + RpcKeyFormat format) { + switch (format) { + case RpcKeyFormat::PEM: + return fromPem(data, PEM_read_bio_PrivateKey); + case RpcKeyFormat::DER: + return fromDer(data, d2i_AutoPrivateKey); + } + LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format)); +} + +std::vector<uint8_t> serializeUnencryptedPrivatekey(EVP_PKEY* pkey, RpcKeyFormat format) { + switch (format) { + case RpcKeyFormat::PEM: + return serialize(pkey, [](BIO* bio, EVP_PKEY* pkey) { + return PEM_write_bio_PrivateKey(bio, pkey, nullptr /* enc */, nullptr /* kstr */, + 0 /* klen */, nullptr, nullptr); + }); + case RpcKeyFormat::DER: + return serialize(pkey, i2d_PrivateKey_bio); + } + LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format)); } } // namespace android diff --git a/libs/binder/include/binder/RpcKeyFormat.h b/libs/binder/include/binder/RpcKeyFormat.h new file mode 100644 index 0000000000..5099c2eacb --- /dev/null +++ b/libs/binder/include/binder/RpcKeyFormat.h @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Formats for serializing TLS private keys. + +#pragma once + +#include <string> + +namespace android { + +enum class RpcKeyFormat { + PEM, + DER, +}; + +static inline std::string PrintToString(RpcKeyFormat format) { + switch (format) { + case RpcKeyFormat::PEM: + return "PEM"; + case RpcKeyFormat::DER: + return "DER"; + default: + return "<unknown>"; + } +} + +} // namespace android diff --git a/libs/binder/include_tls/binder/RpcTlsUtils.h b/libs/binder/include_tls/binder/RpcTlsUtils.h index 8d07835344..591926b8e1 100644 --- a/libs/binder/include_tls/binder/RpcTlsUtils.h +++ b/libs/binder/include_tls/binder/RpcTlsUtils.h @@ -23,12 +23,20 @@ #include <openssl/ssl.h> #include <binder/RpcCertificateFormat.h> +#include <binder/RpcKeyFormat.h> namespace android { -bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& cert, +bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& data, RpcCertificateFormat format); std::vector<uint8_t> serializeCertificate(X509* x509, RpcCertificateFormat format); +// Deserialize an un-encrypted private key. +bssl::UniquePtr<EVP_PKEY> deserializeUnencryptedPrivatekey(const std::vector<uint8_t>& data, + RpcKeyFormat format); + +// Serialize a private key in un-encrypted form. +std::vector<uint8_t> serializeUnencryptedPrivatekey(EVP_PKEY* pkey, RpcKeyFormat format); + } // namespace android diff --git a/libs/binder/tests/Android.bp b/libs/binder/tests/Android.bp index 6f3c6e2d0a..23c1b14020 100644 --- a/libs/binder/tests/Android.bp +++ b/libs/binder/tests/Android.bp @@ -173,6 +173,37 @@ cc_test { require_root: true, } +cc_test { + name: "RpcTlsUtilsTest", + host_supported: true, + target: { + darwin: { + enabled: false, + }, + android: { + test_suites: ["vts"], + }, + }, + defaults: [ + "binder_test_defaults", + "libbinder_tls_shared_deps", + ], + srcs: [ + "RpcAuthTesting.cpp", + "RpcTlsUtilsTest.cpp", + ], + shared_libs: [ + "libbinder", + "libbase", + "libutils", + "liblog", + ], + static_libs: [ + "libbinder_tls_static", + ], + test_suites: ["general-tests", "device-tests"], +} + cc_benchmark { name: "binderRpcBenchmark", defaults: ["binder_test_defaults"], diff --git a/libs/binder/tests/RpcAuthTesting.cpp b/libs/binder/tests/RpcAuthTesting.cpp index 76f7bce863..c0587a2367 100644 --- a/libs/binder/tests/RpcAuthTesting.cpp +++ b/libs/binder/tests/RpcAuthTesting.cpp @@ -70,4 +70,14 @@ status_t RpcAuthSelfSigned::configure(SSL_CTX* ctx) { return OK; } +status_t RpcAuthPreSigned::configure(SSL_CTX* ctx) { + if (!SSL_CTX_use_PrivateKey(ctx, mPkey.get())) { + return INVALID_OPERATION; + } + if (!SSL_CTX_use_certificate(ctx, mCert.get())) { + return INVALID_OPERATION; + } + return OK; +} + } // namespace android diff --git a/libs/binder/tests/RpcAuthTesting.h b/libs/binder/tests/RpcAuthTesting.h index fdc731d01e..c3c2df4c29 100644 --- a/libs/binder/tests/RpcAuthTesting.h +++ b/libs/binder/tests/RpcAuthTesting.h @@ -35,4 +35,15 @@ private: const uint32_t mValidSeconds; }; +class RpcAuthPreSigned : public RpcAuth { +public: + RpcAuthPreSigned(bssl::UniquePtr<EVP_PKEY> pkey, bssl::UniquePtr<X509> cert) + : mPkey(std::move(pkey)), mCert(std::move(cert)) {} + status_t configure(SSL_CTX* ctx) override; + +private: + bssl::UniquePtr<EVP_PKEY> mPkey; + bssl::UniquePtr<X509> mCert; +}; + } // namespace android diff --git a/libs/binder/tests/RpcTlsUtilsTest.cpp b/libs/binder/tests/RpcTlsUtilsTest.cpp new file mode 100644 index 0000000000..9b3078d5d9 --- /dev/null +++ b/libs/binder/tests/RpcTlsUtilsTest.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <binder/RpcTlsUtils.h> +#include <gtest/gtest.h> + +#include "RpcAuthTesting.h" + +namespace android { + +std::string toDebugString(EVP_PKEY* pkey) { + bssl::UniquePtr<BIO> bio(BIO_new(BIO_s_mem())); + int res = EVP_PKEY_print_public(bio.get(), pkey, 2, nullptr); + std::string buf = "\nEVP_PKEY_print_public -> " + std::to_string(res) + "\n"; + if (BIO_write(bio.get(), buf.data(), buf.length()) <= 0) return {}; + res = EVP_PKEY_print_private(bio.get(), pkey, 2, nullptr); + buf = "\nEVP_PKEY_print_private -> " + std::to_string(res); + if (BIO_write(bio.get(), buf.data(), buf.length()) <= 0) return {}; + const uint8_t* data; + size_t len; + if (!BIO_mem_contents(bio.get(), &data, &len)) return {}; + return std::string(reinterpret_cast<const char*>(data), len); +} + +class RpcTlsUtilsKeyTest : public testing::TestWithParam<RpcKeyFormat> { +public: + static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) { + return PrintToString(info.param); + } +}; + +TEST_P(RpcTlsUtilsKeyTest, Test) { + auto pkey = makeKeyPairForSelfSignedCert(); + ASSERT_NE(nullptr, pkey); + auto pkeyData = serializeUnencryptedPrivatekey(pkey.get(), GetParam()); + auto deserializedPkey = deserializeUnencryptedPrivatekey(pkeyData, GetParam()); + ASSERT_NE(nullptr, deserializedPkey); + EXPECT_EQ(1, EVP_PKEY_cmp(pkey.get(), deserializedPkey.get())) + << "expected: " << toDebugString(pkey.get()) + << "\nactual: " << toDebugString(deserializedPkey.get()); +} + +INSTANTIATE_TEST_CASE_P(RpcTlsUtilsTest, RpcTlsUtilsKeyTest, + testing::Values(RpcKeyFormat::PEM, RpcKeyFormat::DER), + RpcTlsUtilsKeyTest::PrintParamInfo); + +class RpcTlsUtilsCertTest : public testing::TestWithParam<RpcCertificateFormat> { +public: + static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) { + return PrintToString(info.param); + } +}; + +TEST_P(RpcTlsUtilsCertTest, Test) { + auto pkey = makeKeyPairForSelfSignedCert(); + ASSERT_NE(nullptr, pkey); + // Make certificate from the original key in memory + auto cert = makeSelfSignedCert(pkey.get(), kCertValidSeconds); + ASSERT_NE(nullptr, cert); + auto certData = serializeCertificate(cert.get(), GetParam()); + auto deserializedCert = deserializeCertificate(certData, GetParam()); + ASSERT_NE(nullptr, deserializedCert); + EXPECT_EQ(0, X509_cmp(cert.get(), deserializedCert.get())); +} + +INSTANTIATE_TEST_CASE_P(RpcTlsUtilsTest, RpcTlsUtilsCertTest, + testing::Values(RpcCertificateFormat::PEM, RpcCertificateFormat::DER), + RpcTlsUtilsCertTest::PrintParamInfo); + +class RpcTlsUtilsKeyAndCertTest + : public testing::TestWithParam<std::tuple<RpcKeyFormat, RpcCertificateFormat>> { +public: + static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) { + auto [keyFormat, certificateFormat] = info.param; + return "key_" + PrintToString(keyFormat) + "_cert_" + PrintToString(certificateFormat); + } +}; + +TEST_P(RpcTlsUtilsKeyAndCertTest, TestCertFromDeserializedKey) { + auto [keyFormat, certificateFormat] = GetParam(); + auto pkey = makeKeyPairForSelfSignedCert(); + ASSERT_NE(nullptr, pkey); + auto pkeyData = serializeUnencryptedPrivatekey(pkey.get(), keyFormat); + auto deserializedPkey = deserializeUnencryptedPrivatekey(pkeyData, keyFormat); + ASSERT_NE(nullptr, deserializedPkey); + + // Make certificate from deserialized key loaded from bytes + auto cert = makeSelfSignedCert(deserializedPkey.get(), kCertValidSeconds); + ASSERT_NE(nullptr, cert); + auto certData = serializeCertificate(cert.get(), certificateFormat); + auto deserializedCert = deserializeCertificate(certData, certificateFormat); + ASSERT_NE(nullptr, deserializedCert); + EXPECT_EQ(0, X509_cmp(cert.get(), deserializedCert.get())); +} + +INSTANTIATE_TEST_CASE_P(RpcTlsUtilsTest, RpcTlsUtilsKeyAndCertTest, + testing::Combine(testing::Values(RpcKeyFormat::PEM, RpcKeyFormat::DER), + testing::Values(RpcCertificateFormat::PEM, + RpcCertificateFormat::DER)), + RpcTlsUtilsKeyAndCertTest::PrintParamInfo); + +} // namespace android diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp index 2fd1a2ae98..0e7e25913b 100644 --- a/libs/binder/tests/binderRpcTest.cpp +++ b/libs/binder/tests/binderRpcTest.cpp @@ -31,6 +31,7 @@ #include <binder/ProcessState.h> #include <binder/RpcServer.h> #include <binder/RpcSession.h> +#include <binder/RpcTlsUtils.h> #include <binder/RpcTransport.h> #include <binder/RpcTransportRaw.h> #include <binder/RpcTransportTls.h> @@ -1439,37 +1440,10 @@ TEST(BinderRpc, Java) { INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcSimple, ::testing::ValuesIn(RpcSecurityValues()), BinderRpcSimple::PrintTestParam); -class RpcTransportTest - : public ::testing::TestWithParam< - std::tuple<SocketType, RpcSecurity, std::optional<RpcCertificateFormat>>> { +class RpcTransportTestUtils { public: + using Param = std::tuple<SocketType, RpcSecurity, std::optional<RpcCertificateFormat>>; using ConnectToServer = std::function<base::unique_fd()>; - static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) { - auto [socketType, rpcSecurity, certificateFormat] = info.param; - auto ret = PrintToString(socketType) + "_" + newFactory(rpcSecurity)->toCString(); - if (certificateFormat.has_value()) ret += "_" + PrintToString(*certificateFormat); - return ret; - } - static std::vector<ParamType> getRpcTranportTestParams() { - std::vector<RpcTransportTest::ParamType> ret; - for (auto socketType : testSocketTypes(false /* hasPreconnected */)) { - for (auto rpcSecurity : RpcSecurityValues()) { - switch (rpcSecurity) { - case RpcSecurity::RAW: { - ret.emplace_back(socketType, rpcSecurity, std::nullopt); - } break; - case RpcSecurity::TLS: { - ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::PEM); - ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::DER); - } break; - } - } - } - return ret; - } - void TearDown() override { - for (auto& server : mServers) server->shutdownAndWait(); - } // A server that handles client socket connections. class Server { @@ -1477,8 +1451,10 @@ public: explicit Server() {} Server(Server&&) = default; ~Server() { shutdownAndWait(); } - [[nodiscard]] AssertionResult setUp() { - auto [socketType, rpcSecurity, certificateFormat] = GetParam(); + [[nodiscard]] AssertionResult setUp( + const Param& param, + std::unique_ptr<RpcAuth> auth = std::make_unique<RpcAuthSelfSigned>()) { + auto [socketType, rpcSecurity, certificateFormat] = param; auto rpcServer = RpcServer::make(newFactory(rpcSecurity)); rpcServer->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction(); switch (socketType) { @@ -1529,7 +1505,7 @@ public: } mFd = rpcServer->releaseServer(); if (!mFd.ok()) return AssertionFailure() << "releaseServer returns invalid fd"; - mCtx = newFactory(rpcSecurity, mCertVerifier)->newServerCtx(); + mCtx = newFactory(rpcSecurity, mCertVerifier, std::move(auth))->newServerCtx(); if (mCtx == nullptr) return AssertionFailure() << "newServerCtx"; mSetup = true; return AssertionSuccess(); @@ -1608,8 +1584,8 @@ public: public: explicit Client(ConnectToServer connectToServer) : mConnectToServer(connectToServer) {} Client(Client&&) = default; - [[nodiscard]] AssertionResult setUp() { - auto [socketType, rpcSecurity, certificateFormat] = GetParam(); + [[nodiscard]] AssertionResult setUp(const Param& param) { + auto [socketType, rpcSecurity, certificateFormat] = param; mFdTrigger = FdTrigger::make(); mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx(); if (mCtx == nullptr) return AssertionFailure() << "newClientCtx"; @@ -1662,8 +1638,9 @@ public: // Make A trust B. template <typename A, typename B> - status_t trust(A* a, B* b) { - auto [socketType, rpcSecurity, certificateFormat] = GetParam(); + static status_t trust(RpcSecurity rpcSecurity, + std::optional<RpcCertificateFormat> certificateFormat, const A& a, + const B& b) { if (rpcSecurity != RpcSecurity::TLS) return OK; LOG_ALWAYS_FATAL_IF(!certificateFormat.has_value()); auto bCert = b->getCtx()->getCertificate(*certificateFormat); @@ -1671,15 +1648,48 @@ public: } static constexpr const char* kMessage = "hello"; - std::vector<std::unique_ptr<Server>> mServers; +}; + +class RpcTransportTest : public testing::TestWithParam<RpcTransportTestUtils::Param> { +public: + using Server = RpcTransportTestUtils::Server; + using Client = RpcTransportTestUtils::Client; + static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) { + auto [socketType, rpcSecurity, certificateFormat] = info.param; + auto ret = PrintToString(socketType) + "_" + newFactory(rpcSecurity)->toCString(); + if (certificateFormat.has_value()) ret += "_" + PrintToString(*certificateFormat); + return ret; + } + static std::vector<ParamType> getRpcTranportTestParams() { + std::vector<ParamType> ret; + for (auto socketType : testSocketTypes(false /* hasPreconnected */)) { + for (auto rpcSecurity : RpcSecurityValues()) { + switch (rpcSecurity) { + case RpcSecurity::RAW: { + ret.emplace_back(socketType, rpcSecurity, std::nullopt); + } break; + case RpcSecurity::TLS: { + ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::PEM); + ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::DER); + } break; + } + } + } + return ret; + } + template <typename A, typename B> + status_t trust(const A& a, const B& b) { + auto [socketType, rpcSecurity, certificateFormat] = GetParam(); + return RpcTransportTestUtils::trust(rpcSecurity, certificateFormat, a, b); + } }; TEST_P(RpcTransportTest, GoodCertificate) { - auto server = mServers.emplace_back(std::make_unique<Server>()).get(); - ASSERT_TRUE(server->setUp()); + auto server = std::make_unique<Server>(); + ASSERT_TRUE(server->setUp(GetParam())); Client client(server->getConnectToServerFn()); - ASSERT_TRUE(client.setUp()); + ASSERT_TRUE(client.setUp(GetParam())); ASSERT_EQ(OK, trust(&client, server)); ASSERT_EQ(OK, trust(server, &client)); @@ -1689,13 +1699,13 @@ TEST_P(RpcTransportTest, GoodCertificate) { } TEST_P(RpcTransportTest, MultipleClients) { - auto server = mServers.emplace_back(std::make_unique<Server>()).get(); - ASSERT_TRUE(server->setUp()); + auto server = std::make_unique<Server>(); + ASSERT_TRUE(server->setUp(GetParam())); std::vector<Client> clients; for (int i = 0; i < 2; i++) { auto& client = clients.emplace_back(server->getConnectToServerFn()); - ASSERT_TRUE(client.setUp()); + ASSERT_TRUE(client.setUp(GetParam())); ASSERT_EQ(OK, trust(&client, server)); ASSERT_EQ(OK, trust(server, &client)); } @@ -1707,11 +1717,11 @@ TEST_P(RpcTransportTest, MultipleClients) { TEST_P(RpcTransportTest, UntrustedServer) { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); - auto untrustedServer = mServers.emplace_back(std::make_unique<Server>()).get(); - ASSERT_TRUE(untrustedServer->setUp()); + auto untrustedServer = std::make_unique<Server>(); + ASSERT_TRUE(untrustedServer->setUp(GetParam())); Client client(untrustedServer->getConnectToServerFn()); - ASSERT_TRUE(client.setUp()); + ASSERT_TRUE(client.setUp(GetParam())); ASSERT_EQ(OK, trust(untrustedServer, &client)); @@ -1724,14 +1734,14 @@ TEST_P(RpcTransportTest, UntrustedServer) { } TEST_P(RpcTransportTest, MaliciousServer) { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); - auto validServer = mServers.emplace_back(std::make_unique<Server>()).get(); - ASSERT_TRUE(validServer->setUp()); + auto validServer = std::make_unique<Server>(); + ASSERT_TRUE(validServer->setUp(GetParam())); - auto maliciousServer = mServers.emplace_back(std::make_unique<Server>()).get(); - ASSERT_TRUE(maliciousServer->setUp()); + auto maliciousServer = std::make_unique<Server>(); + ASSERT_TRUE(maliciousServer->setUp(GetParam())); Client client(maliciousServer->getConnectToServerFn()); - ASSERT_TRUE(client.setUp()); + ASSERT_TRUE(client.setUp(GetParam())); ASSERT_EQ(OK, trust(&client, validServer)); ASSERT_EQ(OK, trust(validServer, &client)); @@ -1747,11 +1757,11 @@ TEST_P(RpcTransportTest, MaliciousServer) { TEST_P(RpcTransportTest, UntrustedClient) { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); - auto server = mServers.emplace_back(std::make_unique<Server>()).get(); - ASSERT_TRUE(server->setUp()); + auto server = std::make_unique<Server>(); + ASSERT_TRUE(server->setUp(GetParam())); Client client(server->getConnectToServerFn()); - ASSERT_TRUE(client.setUp()); + ASSERT_TRUE(client.setUp(GetParam())); ASSERT_EQ(OK, trust(&client, server)); @@ -1766,13 +1776,13 @@ TEST_P(RpcTransportTest, UntrustedClient) { TEST_P(RpcTransportTest, MaliciousClient) { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); - auto server = mServers.emplace_back(std::make_unique<Server>()).get(); - ASSERT_TRUE(server->setUp()); + auto server = std::make_unique<Server>(); + ASSERT_TRUE(server->setUp(GetParam())); Client validClient(server->getConnectToServerFn()); - ASSERT_TRUE(validClient.setUp()); + ASSERT_TRUE(validClient.setUp(GetParam())); Client maliciousClient(server->getConnectToServerFn()); - ASSERT_TRUE(maliciousClient.setUp()); + ASSERT_TRUE(maliciousClient.setUp(GetParam())); ASSERT_EQ(OK, trust(&validClient, server)); ASSERT_EQ(OK, trust(&maliciousClient, server)); @@ -1790,7 +1800,7 @@ TEST_P(RpcTransportTest, Trigger) { std::condition_variable writeCv; bool shouldContinueWriting = false; auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) { - std::string message(kMessage); + std::string message(RpcTransportTestUtils::kMessage); auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size()); if (status != OK) return AssertionFailure() << statusToString(status); @@ -1810,12 +1820,12 @@ TEST_P(RpcTransportTest, Trigger) { return AssertionSuccess(); }; - auto server = mServers.emplace_back(std::make_unique<Server>()).get(); - ASSERT_TRUE(server->setUp()); + auto server = std::make_unique<Server>(); + ASSERT_TRUE(server->setUp(GetParam())); // Set up client Client client(server->getConnectToServerFn()); - ASSERT_TRUE(client.setUp()); + ASSERT_TRUE(client.setUp(GetParam())); // Exchange keys ASSERT_EQ(OK, trust(&client, server)); @@ -1828,7 +1838,7 @@ TEST_P(RpcTransportTest, Trigger) { ASSERT_TRUE(client.setUpTransport()); // read the first message. This ensures that server has finished handshake and start handling // client fd. Server thread should pause at writeCv.wait_for(). - ASSERT_TRUE(client.readMessage(kMessage)); + ASSERT_TRUE(client.readMessage(RpcTransportTestUtils::kMessage)); // Trigger server shutdown after server starts handling client FD. This ensures that the second // write is on an FdTrigger that has been shut down. server->shutdown(); @@ -1848,6 +1858,61 @@ INSTANTIATE_TEST_CASE_P(BinderRpc, RpcTransportTest, ::testing::ValuesIn(RpcTransportTest::getRpcTranportTestParams()), RpcTransportTest::PrintParamInfo); +class RpcTransportTlsKeyTest + : public testing::TestWithParam<std::tuple<SocketType, RpcCertificateFormat, RpcKeyFormat>> { +public: + template <typename A, typename B> + status_t trust(const A& a, const B& b) { + auto [socketType, certificateFormat, keyFormat] = GetParam(); + return RpcTransportTestUtils::trust(RpcSecurity::TLS, certificateFormat, a, b); + } + static std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) { + auto [socketType, certificateFormat, keyFormat] = info.param; + auto ret = PrintToString(socketType) + "_certificate_" + PrintToString(certificateFormat) + + "_key_" + PrintToString(keyFormat); + return ret; + }; +}; + +TEST_P(RpcTransportTlsKeyTest, PreSignedCertificate) { + auto [socketType, certificateFormat, keyFormat] = GetParam(); + + std::vector<uint8_t> pkeyData, certData; + { + auto pkey = makeKeyPairForSelfSignedCert(); + ASSERT_NE(nullptr, pkey); + auto cert = makeSelfSignedCert(pkey.get(), kCertValidSeconds); + ASSERT_NE(nullptr, cert); + pkeyData = serializeUnencryptedPrivatekey(pkey.get(), keyFormat); + certData = serializeCertificate(cert.get(), certificateFormat); + } + + auto desPkey = deserializeUnencryptedPrivatekey(pkeyData, keyFormat); + auto desCert = deserializeCertificate(certData, certificateFormat); + auto auth = std::make_unique<RpcAuthPreSigned>(std::move(desPkey), std::move(desCert)); + auto utilsParam = + std::make_tuple(socketType, RpcSecurity::TLS, std::make_optional(certificateFormat)); + + auto server = std::make_unique<RpcTransportTestUtils::Server>(); + ASSERT_TRUE(server->setUp(utilsParam, std::move(auth))); + + RpcTransportTestUtils::Client client(server->getConnectToServerFn()); + ASSERT_TRUE(client.setUp(utilsParam)); + + ASSERT_EQ(OK, trust(&client, server)); + ASSERT_EQ(OK, trust(server, &client)); + + server->start(); + client.run(); +} + +INSTANTIATE_TEST_CASE_P( + BinderRpc, RpcTransportTlsKeyTest, + testing::Combine(testing::ValuesIn(testSocketTypes(false /* hasPreconnected*/)), + testing::Values(RpcCertificateFormat::PEM, RpcCertificateFormat::DER), + testing::Values(RpcKeyFormat::PEM, RpcKeyFormat::DER)), + RpcTransportTlsKeyTest::PrintParamInfo); + } // namespace android int main(int argc, char** argv) { |