| /* |
| * Copyright (C) 2020 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "adb/pairing/pairing_server.h" |
| |
| #include <sys/epoll.h> |
| #include <sys/eventfd.h> |
| |
| #include <atomic> |
| #include <deque> |
| #include <iomanip> |
| #include <mutex> |
| #include <sstream> |
| #include <thread> |
| #include <tuple> |
| #include <unordered_map> |
| #include <variant> |
| #include <vector> |
| |
| #include <adb/crypto/rsa_2048_key.h> |
| #include <adb/crypto/x509_generator.h> |
| #include <adb/pairing/pairing_connection.h> |
| #include <android-base/logging.h> |
| #include <android-base/parsenetaddress.h> |
| #include <android-base/thread_annotations.h> |
| #include <android-base/unique_fd.h> |
| #include <cutils/sockets.h> |
| |
| #include "internal/constants.h" |
| |
| using android::base::ScopedLockAssertion; |
| using android::base::unique_fd; |
| using namespace adb::crypto; |
| using namespace adb::pairing; |
| |
| // The implementation has two background threads running: one to handle and |
| // accept any new pairing connection requests (socket accept), and the other to |
| // handle connection events (connection started, connection finished). |
| struct PairingServerCtx { |
| public: |
| using Data = std::vector<uint8_t>; |
| |
| virtual ~PairingServerCtx(); |
| |
| // All parameters must be non-empty. |
| explicit PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert, |
| const Data& priv_key, uint16_t port); |
| |
| // Starts the pairing server. This call is non-blocking. Upon completion, |
| // if the pairing was successful, then |cb| will be called with the PublicKeyHeader |
| // containing the info of the trusted peer. Otherwise, |cb| will be |
| // called with an empty value. Start can only be called once in the lifetime |
| // of this object. |
| // |
| // Returns the port number if PairingServerCtx was successfully started. Otherwise, |
| // returns 0. |
| uint16_t Start(pairing_server_result_cb cb, void* opaque); |
| |
| private: |
| // Setup the server socket to accept incoming connections. Returns the |
| // server port number (> 0 on success). |
| uint16_t SetupServer(); |
| // Force stop the server thread. |
| void StopServer(); |
| |
| // handles a new pairing client connection |
| bool HandleNewClientConnection(int fd) EXCLUDES(conn_mutex_); |
| |
| // ======== connection events thread ============= |
| std::mutex conn_mutex_; |
| std::condition_variable conn_cv_; |
| |
| using FdVal = int; |
| struct ConnectionDeleter { |
| void operator()(PairingConnectionCtx* p) { pairing_connection_destroy(p); } |
| }; |
| using ConnectionPtr = std::unique_ptr<PairingConnectionCtx, ConnectionDeleter>; |
| static ConnectionPtr CreatePairingConnection(const Data& pswd, const PeerInfo& peer_info, |
| const Data& cert, const Data& priv_key); |
| using NewConnectionEvent = std::tuple<unique_fd, ConnectionPtr>; |
| // <fd, PeerInfo.type, PeerInfo.data> |
| using ConnectionFinishedEvent = std::tuple<FdVal, uint8_t, std::optional<std::string>>; |
| using ConnectionEvent = std::variant<NewConnectionEvent, ConnectionFinishedEvent>; |
| // Queue for connections to write into. We have a separate queue to read |
| // from, in order to minimize the time the server thread is blocked. |
| std::deque<ConnectionEvent> conn_write_queue_ GUARDED_BY(conn_mutex_); |
| std::deque<ConnectionEvent> conn_read_queue_; |
| // Map of fds to their PairingConnections currently running. |
| std::unordered_map<FdVal, ConnectionPtr> connections_; |
| |
| // Two threads launched when starting the pairing server: |
| // 1) A server thread that waits for incoming client connections, and |
| // 2) A connection events thread that synchonizes events from all of the |
| // clients, since each PairingConnection is running in it's own thread. |
| void StartConnectionEventsThread(); |
| void StartServerThread(); |
| |
| static void PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque); |
| |
| std::thread conn_events_thread_; |
| void ConnectionEventsWorker(); |
| std::thread server_thread_; |
| void ServerWorker(); |
| bool is_terminate_ GUARDED_BY(conn_mutex_) = false; |
| |
| enum class State { |
| Ready, |
| Running, |
| Stopped, |
| }; |
| State state_ = State::Ready; |
| Data pswd_; |
| PeerInfo peer_info_; |
| Data cert_; |
| Data priv_key_; |
| uint16_t port_; |
| |
| pairing_server_result_cb cb_; |
| void* opaque_ = nullptr; |
| bool got_valid_pairing_ = false; |
| |
| static const int kEpollConstSocket = 0; |
| // Used to break the server thread from epoll_wait |
| static const int kEpollConstEventFd = 1; |
| unique_fd epoll_fd_; |
| unique_fd server_fd_; |
| unique_fd event_fd_; |
| }; // PairingServerCtx |
| |
| // static |
| PairingServerCtx::ConnectionPtr PairingServerCtx::CreatePairingConnection(const Data& pswd, |
| const PeerInfo& peer_info, |
| const Data& cert, |
| const Data& priv_key) { |
| return ConnectionPtr(pairing_connection_server_new(pswd.data(), pswd.size(), &peer_info, |
| cert.data(), cert.size(), priv_key.data(), |
| priv_key.size())); |
| } |
| |
| PairingServerCtx::PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert, |
| const Data& priv_key, uint16_t port) |
| : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) { |
| CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); |
| } |
| |
| PairingServerCtx::~PairingServerCtx() { |
| // Since these connections have references to us, let's make sure they |
| // destruct before us. |
| if (server_thread_.joinable()) { |
| StopServer(); |
| server_thread_.join(); |
| } |
| |
| { |
| std::lock_guard<std::mutex> lock(conn_mutex_); |
| is_terminate_ = true; |
| } |
| conn_cv_.notify_one(); |
| if (conn_events_thread_.joinable()) { |
| conn_events_thread_.join(); |
| } |
| |
| // Notify the cb_ if it hasn't already. |
| if (!got_valid_pairing_ && cb_ != nullptr) { |
| cb_(nullptr, opaque_); |
| } |
| } |
| |
| uint16_t PairingServerCtx::Start(pairing_server_result_cb cb, void* opaque) { |
| cb_ = cb; |
| opaque_ = opaque; |
| |
| if (state_ != State::Ready) { |
| LOG(ERROR) << "PairingServerCtx already running or stopped"; |
| return 0; |
| } |
| |
| port_ = SetupServer(); |
| if (port_ == 0) { |
| LOG(ERROR) << "Unable to start PairingServer"; |
| state_ = State::Stopped; |
| return 0; |
| } |
| LOG(INFO) << "Pairing server started on port " << port_; |
| |
| state_ = State::Running; |
| return port_; |
| } |
| |
| void PairingServerCtx::StopServer() { |
| if (event_fd_.get() == -1) { |
| return; |
| } |
| uint64_t value = 1; |
| ssize_t rc = write(event_fd_.get(), &value, sizeof(value)); |
| if (rc == -1) { |
| // This can happen if the server didn't start. |
| PLOG(ERROR) << "write to eventfd failed"; |
| } else if (rc != sizeof(value)) { |
| LOG(FATAL) << "write to event returned short (" << rc << ")"; |
| } |
| } |
| |
| uint16_t PairingServerCtx::SetupServer() { |
| epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC)); |
| if (epoll_fd_ == -1) { |
| PLOG(ERROR) << "failed to create epoll fd"; |
| return 0; |
| } |
| |
| event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); |
| if (event_fd_ == -1) { |
| PLOG(ERROR) << "failed to create eventfd"; |
| return 0; |
| } |
| |
| server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM)); |
| if (server_fd_.get() == -1) { |
| PLOG(ERROR) << "Failed to start pairing connection server"; |
| return 0; |
| } |
| |
| StartConnectionEventsThread(); |
| StartServerThread(); |
| int port = socket_get_local_port(server_fd_.get()); |
| return (port <= 0 ? 0 : port); |
| } |
| |
| void PairingServerCtx::StartServerThread() { |
| server_thread_ = std::thread([this]() { ServerWorker(); }); |
| } |
| |
| void PairingServerCtx::StartConnectionEventsThread() { |
| conn_events_thread_ = std::thread([this]() { ConnectionEventsWorker(); }); |
| } |
| |
| void PairingServerCtx::ServerWorker() { |
| { |
| struct epoll_event event; |
| event.events = EPOLLIN; |
| event.data.u64 = kEpollConstSocket; |
| CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event)); |
| } |
| |
| { |
| struct epoll_event event; |
| event.events = EPOLLIN; |
| event.data.u64 = kEpollConstEventFd; |
| CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event)); |
| } |
| |
| while (true) { |
| struct epoll_event events[2]; |
| int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1)); |
| if (rc == -1) { |
| PLOG(ERROR) << "epoll_wait failed"; |
| return; |
| } else if (rc == 0) { |
| LOG(ERROR) << "epoll_wait returned 0"; |
| return; |
| } |
| |
| for (int i = 0; i < rc; ++i) { |
| struct epoll_event& event = events[i]; |
| switch (event.data.u64) { |
| case kEpollConstSocket: |
| HandleNewClientConnection(server_fd_.get()); |
| break; |
| case kEpollConstEventFd: |
| uint64_t dummy; |
| int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy))); |
| if (rc != sizeof(dummy)) { |
| PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")"; |
| } |
| return; |
| } |
| } |
| } |
| } |
| |
| // static |
| void PairingServerCtx::PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque) { |
| auto* p = reinterpret_cast<PairingServerCtx*>(opaque); |
| |
| ConnectionFinishedEvent event; |
| if (peer_info != nullptr) { |
| if (peer_info->type == ADB_RSA_PUB_KEY) { |
| event = std::make_tuple(fd, peer_info->type, |
| std::string(reinterpret_cast<const char*>(peer_info->data))); |
| } else { |
| LOG(WARNING) << "Ignoring successful pairing because of unknown " |
| << "PeerInfo type=" << peer_info->type; |
| } |
| } else { |
| event = std::make_tuple(fd, 0, std::nullopt); |
| } |
| { |
| std::lock_guard<std::mutex> lock(p->conn_mutex_); |
| p->conn_write_queue_.push_back(std::move(event)); |
| } |
| p->conn_cv_.notify_one(); |
| } |
| |
| void PairingServerCtx::ConnectionEventsWorker() { |
| uint8_t num_tries = 0; |
| for (;;) { |
| // Transfer the write queue to the read queue. |
| { |
| std::unique_lock<std::mutex> lock(conn_mutex_); |
| ScopedLockAssertion assume_locked(conn_mutex_); |
| |
| if (is_terminate_) { |
| // We check |is_terminate_| twice because condition_variable's |
| // notify() only wakes up a thread if it is in the wait state |
| // prior to notify(). Furthermore, we aren't holding the mutex |
| // when processing the events in |conn_read_queue_|. |
| return; |
| } |
| if (conn_write_queue_.empty()) { |
| // We need to wait for new events, or the termination signal. |
| conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) { |
| return (is_terminate_ || !conn_write_queue_.empty()); |
| }); |
| } |
| if (is_terminate_) { |
| // We're done. |
| return; |
| } |
| // Move all events into the read queue. |
| conn_read_queue_ = std::move(conn_write_queue_); |
| conn_write_queue_.clear(); |
| } |
| |
| // Process all events in the read queue. |
| while (conn_read_queue_.size() > 0) { |
| auto& event = conn_read_queue_.front(); |
| if (auto* p = std::get_if<NewConnectionEvent>(&event)) { |
| // Ignore if we are already at the max number of connections |
| if (connections_.size() >= internal::kMaxConnections) { |
| conn_read_queue_.pop_front(); |
| continue; |
| } |
| auto [ufd, connection] = std::move(*p); |
| int fd = ufd.release(); |
| bool started = pairing_connection_start(connection.get(), fd, |
| PairingConnectionCallback, this); |
| if (!started) { |
| LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd; |
| ufd.reset(fd); |
| } else { |
| connections_[fd] = std::move(connection); |
| } |
| } else if (auto* p = std::get_if<ConnectionFinishedEvent>(&event)) { |
| auto [fd, info_type, public_key] = std::move(*p); |
| if (public_key.has_value() && !public_key->empty()) { |
| // Valid pairing. Let's shutdown the server and close any |
| // pairing connections in progress. |
| StopServer(); |
| connections_.clear(); |
| |
| PeerInfo info = {}; |
| info.type = info_type; |
| strncpy(reinterpret_cast<char*>(info.data), public_key->data(), |
| public_key->size()); |
| |
| cb_(&info, opaque_); |
| |
| got_valid_pairing_ = true; |
| return; |
| } |
| // Invalid pairing. Close the invalid connection. |
| if (connections_.find(fd) != connections_.end()) { |
| connections_.erase(fd); |
| } |
| |
| if (++num_tries >= internal::kMaxPairingAttempts) { |
| cb_(nullptr, opaque_); |
| // To prevent the destructor from calling it again. |
| cb_ = nullptr; |
| return; |
| } |
| } |
| conn_read_queue_.pop_front(); |
| } |
| } |
| } |
| |
| bool PairingServerCtx::HandleNewClientConnection(int fd) { |
| unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC))); |
| if (ufd == -1) { |
| PLOG(WARNING) << "adb_socket_accept failed fd=" << fd; |
| return false; |
| } |
| auto connection = CreatePairingConnection(pswd_, peer_info_, cert_, priv_key_); |
| if (connection == nullptr) { |
| LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd; |
| return false; |
| } |
| // send the new connection to the connection thread for further processing |
| NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection)); |
| { |
| std::lock_guard<std::mutex> lock(conn_mutex_); |
| conn_write_queue_.push_back(std::move(event)); |
| } |
| conn_cv_.notify_one(); |
| |
| return true; |
| } |
| |
| uint16_t pairing_server_start(PairingServerCtx* ctx, pairing_server_result_cb cb, void* opaque) { |
| return ctx->Start(cb, opaque); |
| } |
| |
| PairingServerCtx* pairing_server_new(const uint8_t* pswd, size_t pswd_len, |
| const PeerInfo* peer_info, const uint8_t* x509_cert_pem, |
| size_t x509_size, const uint8_t* priv_key_pem, |
| size_t priv_size, uint16_t port) { |
| CHECK(pswd); |
| CHECK_GT(pswd_len, 0U); |
| CHECK(x509_cert_pem); |
| CHECK_GT(x509_size, 0U); |
| CHECK(priv_key_pem); |
| CHECK_GT(priv_size, 0U); |
| CHECK(peer_info); |
| std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len); |
| std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size); |
| std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size); |
| return new PairingServerCtx(vec_pswd, *peer_info, vec_x509_cert, vec_priv_key, port); |
| } |
| |
| PairingServerCtx* pairing_server_new_no_cert(const uint8_t* pswd, size_t pswd_len, |
| const PeerInfo* peer_info, uint16_t port) { |
| auto rsa_2048 = CreateRSA2048Key(); |
| auto x509_cert = GenerateX509Certificate(rsa_2048->GetEvpPkey()); |
| std::string pkey_pem = Key::ToPEMString(rsa_2048->GetEvpPkey()); |
| std::string cert_pem = X509ToPEMString(x509_cert.get()); |
| |
| return pairing_server_new(pswd, pswd_len, peer_info, |
| reinterpret_cast<const uint8_t*>(cert_pem.data()), cert_pem.size(), |
| reinterpret_cast<const uint8_t*>(pkey_pem.data()), pkey_pem.size(), |
| port); |
| } |
| |
| void pairing_server_destroy(PairingServerCtx* ctx) { |
| CHECK(ctx); |
| delete ctx; |
| } |