diff options
| -rw-r--r-- | libs/vr/libpdx_uds/client_channel_factory.cpp | 15 | ||||
| -rw-r--r-- | libs/vr/libpdx_uds/private/uds/ipc_helper.h | 9 | ||||
| -rw-r--r-- | libs/vr/libpdx_uds/private/uds/service_endpoint.h | 2 | ||||
| -rw-r--r-- | libs/vr/libpdx_uds/service_endpoint.cpp | 84 |
4 files changed, 77 insertions, 33 deletions
diff --git a/libs/vr/libpdx_uds/client_channel_factory.cpp b/libs/vr/libpdx_uds/client_channel_factory.cpp index 850c6d31ad..433f459769 100644 --- a/libs/vr/libpdx_uds/client_channel_factory.cpp +++ b/libs/vr/libpdx_uds/client_channel_factory.cpp @@ -60,7 +60,7 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect( bool connected = socket_.IsValid(); if (!connected) { - socket_.Reset(socket(AF_UNIX, SOCK_STREAM, 0)); + socket_.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0)); LOG_ALWAYS_FATAL_IF( endpoint_path_.empty(), "ClientChannelFactory::Connect: unspecified socket path"); @@ -123,6 +123,15 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect( connected = true; ALOGD("ClientChannelFactory: Connected successfully to %s...", remote.sun_path); + ChannelConnectionInfo<LocalHandle> connection_info; + status = ReceiveData(socket_.Borrow(), &connection_info); + if (!status) + return status.error_status(); + socket_ = std::move(connection_info.channel_fd); + if (!socket_) { + ALOGE("ClientChannelFactory::Connect: Failed to obtain channel socket"); + return ErrorStatus(EIO); + } } if (use_timeout) now = steady_clock::now(); @@ -132,11 +141,11 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect( InitRequest(&request, opcodes::CHANNEL_OPEN, 0, 0, false); status = SendData(socket_.Borrow(), request); if (!status) - return ErrorStatus(status.error()); + return status.error_status(); ResponseHeader<LocalHandle> response; status = ReceiveData(socket_.Borrow(), &response); if (!status) - return ErrorStatus(status.error()); + return status.error_status(); int ref = response.ret_code; if (ref < 0 || static_cast<size_t>(ref) > response.file_descriptors.size()) return ErrorStatus(EIO); diff --git a/libs/vr/libpdx_uds/private/uds/ipc_helper.h b/libs/vr/libpdx_uds/private/uds/ipc_helper.h index 5b7e5ffd57..bde16d3d31 100644 --- a/libs/vr/libpdx_uds/private/uds/ipc_helper.h +++ b/libs/vr/libpdx_uds/private/uds/ipc_helper.h @@ -116,6 +116,15 @@ class ChannelInfo { }; template <typename FileHandleType> +class ChannelConnectionInfo { + public: + FileHandleType channel_fd; + + private: + PDX_SERIALIZABLE_MEMBERS(ChannelConnectionInfo, channel_fd); +}; + +template <typename FileHandleType> class RequestHeader { public: int32_t op{0}; diff --git a/libs/vr/libpdx_uds/private/uds/service_endpoint.h b/libs/vr/libpdx_uds/private/uds/service_endpoint.h index eb87827939..368891ce05 100644 --- a/libs/vr/libpdx_uds/private/uds/service_endpoint.h +++ b/libs/vr/libpdx_uds/private/uds/service_endpoint.h @@ -142,6 +142,8 @@ class Endpoint : public pdx::Endpoint { BorrowedHandle GetChannelSocketFd(int32_t channel_id); BorrowedHandle GetChannelEventFd(int32_t channel_id); int32_t GetChannelId(const BorrowedHandle& channel_fd); + Status<void> CreateChannelSocketPair(LocalHandle* local_socket, + LocalHandle* remote_socket); std::string endpoint_path_; bool is_blocking_; diff --git a/libs/vr/libpdx_uds/service_endpoint.cpp b/libs/vr/libpdx_uds/service_endpoint.cpp index 6c92259b16..d96eeff230 100644 --- a/libs/vr/libpdx_uds/service_endpoint.cpp +++ b/libs/vr/libpdx_uds/service_endpoint.cpp @@ -214,30 +214,42 @@ Status<void> Endpoint::AcceptConnection(Message* message) { sockaddr_un remote; socklen_t addrlen = sizeof(remote); - LocalHandle channel_fd{accept4(socket_fd_.Get(), - reinterpret_cast<sockaddr*>(&remote), &addrlen, - SOCK_CLOEXEC)}; - if (!channel_fd) { + LocalHandle connection_fd{accept4(socket_fd_.Get(), + reinterpret_cast<sockaddr*>(&remote), + &addrlen, SOCK_CLOEXEC)}; + if (!connection_fd) { ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s", strerror(errno)); return ErrorStatus(errno); } - int optval = 1; - if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval, - sizeof(optval)) == -1) { - ALOGE( - "Endpoint::AcceptConnection: Failed to enable the receiving of the " - "credentials for channel %d: %s", - channel_fd.Get(), strerror(errno)); - return ErrorStatus(errno); + LocalHandle local_socket; + LocalHandle remote_socket; + auto status = CreateChannelSocketPair(&local_socket, &remote_socket); + if (!status) + return status; + + // Borrow the local channel handle before we move it into OnNewChannel(). + BorrowedHandle channel_handle = local_socket.Borrow(); + status = OnNewChannel(std::move(local_socket)); + if (!status) + return status; + + // Send the channel socket fd to the client. + ChannelConnectionInfo<LocalHandle> connection_info; + connection_info.channel_fd = std::move(remote_socket); + status = SendData(connection_fd.Borrow(), connection_info); + + if (status) { + // Get the CHANNEL_OPEN message from client over the channel socket. + status = ReceiveMessageForChannel(channel_handle, message); + } else { + CloseChannel(GetChannelId(channel_handle)); } - // Borrow the channel handle before we pass (move) it into OnNewChannel(). - BorrowedHandle borrowed_channel_handle = channel_fd.Borrow(); - auto status = OnNewChannel(std::move(channel_fd)); - if (status) - status = ReceiveMessageForChannel(borrowed_channel_handle, message); + // Don't need the connection socket anymore. Further communication should + // happen over the channel socket. + shutdown(connection_fd.Get(), SHUT_WR); return status; } @@ -349,29 +361,41 @@ Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask, return ErrorStatus{EINVAL}; } -Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message, - int /*flags*/, - Channel* channel, - int* channel_id) { +Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket, + LocalHandle* remote_socket) { + Status<void> status; int channel_pair[2] = {}; if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) { - ALOGE("Endpoint::PushChannel: Failed to create a socket pair: %s", + ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s", strerror(errno)); - return ErrorStatus(errno); + status.SetError(errno); + return status; } - LocalHandle local_socket{channel_pair[0]}; - LocalHandle remote_socket{channel_pair[1]}; + local_socket->Reset(channel_pair[0]); + remote_socket->Reset(channel_pair[1]); int optval = 1; - if (setsockopt(local_socket.Get(), SOL_SOCKET, SO_PASSCRED, &optval, + if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval, sizeof(optval)) == -1) { ALOGE( - "Endpoint::PushChannel: Failed to enable the receiving of the " - "credentials for channel %d: %s", - local_socket.Get(), strerror(errno)); - return ErrorStatus(errno); + "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of " + "the credentials for channel %d: %s", + local_socket->Get(), strerror(errno)); + status.SetError(errno); } + return status; +} + +Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message, + int /*flags*/, + Channel* channel, + int* channel_id) { + LocalHandle local_socket; + LocalHandle remote_socket; + auto status = CreateChannelSocketPair(&local_socket, &remote_socket); + if (!status) + return status.error_status(); std::lock_guard<std::mutex> autolock(channel_mutex_); auto channel_data = OnNewChannelLocked(std::move(local_socket), channel); |