diff options
author | 2023-02-14 19:44:40 +0000 | |
---|---|---|
committer | 2023-02-24 22:17:47 +0000 | |
commit | 8cab37dc2c701213e233ec42cac392b29c9d8ce9 (patch) | |
tree | ffde2c59f0b7e13872ac9d443781c23a04b5a8dc | |
parent | e59a4b6e4208b43ab5b85e1f8319da4ddc471b93 (diff) |
[Private GATT] Add support for GATT write requests
Bug: 255880936
Test: unit
Change-Id: Iec10e55b720729bd88e8a3889d9f47150fb693ee
-rw-r--r-- | system/rust/src/gatt/callbacks.rs | 24 | ||||
-rw-r--r-- | system/rust/src/gatt/callbacks/callback_transaction_manager.rs | 23 | ||||
-rw-r--r-- | system/rust/src/gatt/ffi.rs | 40 | ||||
-rw-r--r-- | system/rust/src/gatt/ffi/gatt_shim.cc | 22 | ||||
-rw-r--r-- | system/rust/src/gatt/ffi/gatt_shim.h | 5 | ||||
-rw-r--r-- | system/rust/src/gatt/mocks/mock_callbacks.rs | 42 | ||||
-rw-r--r-- | system/rust/src/gatt/mocks/mock_datastore.rs | 23 | ||||
-rw-r--r-- | system/rust/src/gatt/server/att_database.rs | 19 | ||||
-rw-r--r-- | system/rust/src/gatt/server/att_server_bearer.rs | 17 | ||||
-rw-r--r-- | system/rust/src/gatt/server/gatt_database.rs | 149 | ||||
-rw-r--r-- | system/rust/src/gatt/server/test/test_att_db.rs | 31 | ||||
-rw-r--r-- | system/rust/src/gatt/server/transaction_handler.rs | 7 | ||||
-rw-r--r-- | system/rust/src/gatt/server/transactions.rs | 1 | ||||
-rw-r--r-- | system/rust/src/gatt/server/transactions/write_request.rs | 100 | ||||
-rw-r--r-- | system/rust/src/utils.rs | 1 | ||||
-rw-r--r-- | system/rust/src/utils/task.rs | 10 | ||||
-rw-r--r-- | system/rust/tests/gatt_callbacks_test.rs | 61 | ||||
-rw-r--r-- | system/rust/tests/gatt_server_test.rs | 53 |
18 files changed, 594 insertions, 34 deletions
diff --git a/system/rust/src/gatt/callbacks.rs b/system/rust/src/gatt/callbacks.rs index 5bed2b511e..c8e9ce28b4 100644 --- a/system/rust/src/gatt/callbacks.rs +++ b/system/rust/src/gatt/callbacks.rs @@ -8,7 +8,7 @@ pub use callback_transaction_manager::{CallbackResponseError, CallbackTransactio use async_trait::async_trait; -use crate::packets::{AttAttributeDataChild, AttErrorCode}; +use crate::packets::{AttAttributeDataChild, AttAttributeDataView, AttErrorCode}; use super::ids::{AttHandle, ConnectionId, TransactionId}; @@ -25,6 +25,20 @@ pub trait GattCallbacks { offset: u32, is_long: bool, ); + + /// Invoked when a client tries to write a characteristic. Expects a + /// response using bluetooth::gatt::send_response(); + #[allow(clippy::too_many_arguments)] // needed to match the C++ interface + fn on_server_write_characteristic( + &self, + conn_id: ConnectionId, + trans_id: TransactionId, + handle: AttHandle, + offset: u32, + need_response: bool, + is_prepare: bool, + value: AttAttributeDataView, + ); } /// This interface is an "async" version of the above, and is passed directly @@ -44,4 +58,12 @@ pub trait GattDatastore { conn_id: ConnectionId, handle: AttHandle, ) -> Result<AttAttributeDataChild, AttErrorCode>; + + /// Write data to a given characteristic on the specified connection. + async fn write_characteristic( + &self, + conn_id: ConnectionId, + handle: AttHandle, + data: AttAttributeDataView<'_>, + ) -> Result<(), AttErrorCode>; } diff --git a/system/rust/src/gatt/callbacks/callback_transaction_manager.rs b/system/rust/src/gatt/callbacks/callback_transaction_manager.rs index 65f34fd77c..0406a1e000 100644 --- a/system/rust/src/gatt/callbacks/callback_transaction_manager.rs +++ b/system/rust/src/gatt/callbacks/callback_transaction_manager.rs @@ -9,7 +9,7 @@ use crate::{ ids::{AttHandle, ConnectionId, TransactionId}, GattCallbacks, }, - packets::{AttAttributeDataChild, AttErrorCode}, + packets::{AttAttributeDataChild, AttAttributeDataView, AttErrorCode}, }; use super::GattDatastore; @@ -138,4 +138,25 @@ impl GattDatastore for CallbackTransactionManager { Err(AttErrorCode::UNLIKELY_ERROR) } } + + async fn write_characteristic( + &self, + conn_id: ConnectionId, + handle: AttHandle, + data: AttAttributeDataView<'_>, + ) -> Result<(), AttErrorCode> { + let (trans_id, rx) = + self.pending_transactions.borrow_mut().start_new_transaction(conn_id)?; + + self.callbacks + .on_server_write_characteristic(conn_id, trans_id, handle, 0, true, false, data); + + if let Ok(value) = rx.await { + value.map(|_| ()) // the data passed back is irrelevant for write + // requests + } else { + warn!("sender side of {trans_id:?} dropped while handling request - most likely this response will not be sent over the air"); + Err(AttErrorCode::UNLIKELY_ERROR) + } + } } diff --git a/system/rust/src/gatt/ffi.rs b/system/rust/src/gatt/ffi.rs index dfc41e460a..d56c35599f 100644 --- a/system/rust/src/gatt/ffi.rs +++ b/system/rust/src/gatt/ffi.rs @@ -11,7 +11,10 @@ use log::{error, info, warn}; use crate::{ do_in_rust_thread, - packets::{AttAttributeDataChild, AttBuilder, AttErrorCode, Serializable, SerializeError}, + packets::{ + AttAttributeDataChild, AttAttributeDataView, AttBuilder, AttErrorCode, Serializable, + SerializeError, + }, }; use super::{ @@ -54,6 +57,20 @@ mod inner { offset: u32, is_long: bool, ); + + /// This callback is invoked when writing a characteristic - the client + /// must reply using SendResponse + #[cxx_name = "OnServerWriteCharacteristic"] + fn on_server_write_characteristic( + self: &GattServerCallbacks, + conn_id: u16, + trans_id: u32, + attr_handle: u16, + offset: u32, + need_response: bool, + is_prepare: bool, + value: &[u8], + ); } /// What action the arbiter should take in response to an incoming packet @@ -143,6 +160,27 @@ impl GattCallbacks for GattCallbacksImpl { .unwrap() .on_server_read_characteristic(conn_id.0, trans_id.0, handle.0, offset, is_long); } + + fn on_server_write_characteristic( + &self, + conn_id: ConnectionId, + trans_id: TransactionId, + handle: AttHandle, + offset: u32, + need_response: bool, + is_prepare: bool, + value: AttAttributeDataView, + ) { + self.0.as_ref().unwrap().on_server_write_characteristic( + conn_id.0, + trans_id.0, + handle.0, + offset, + need_response, + is_prepare, + &value.get_raw_payload().collect::<Vec<_>>(), + ); + } } /// Implementation of AttTransport wrapping the corresponding C++ method diff --git a/system/rust/src/gatt/ffi/gatt_shim.cc b/system/rust/src/gatt/ffi/gatt_shim.cc index 59d4070d79..58e68f47a3 100644 --- a/system/rust/src/gatt/ffi/gatt_shim.cc +++ b/system/rust/src/gatt/ffi/gatt_shim.cc @@ -69,5 +69,27 @@ void GattServerCallbacks::OnServerReadCharacteristic(uint16_t conn_id, addr.value(), attr_handle, offset, is_long)); } +void GattServerCallbacks::OnServerWriteCharacteristic( + uint16_t conn_id, uint32_t trans_id, uint16_t attr_handle, uint32_t offset, + bool need_response, bool is_prepare, + ::rust::Slice<const uint8_t> value) const { + auto addr = AddressOfConnection(conn_id); + if (!addr.has_value()) { + LOG_WARN( + "Dropping server write characteristic since connection %d not found", + conn_id); + return; + } + + auto buf = new uint8_t[value.size()]; + std::copy(value.begin(), value.end(), buf); + + do_in_jni_thread( + FROM_HERE, + base::Bind(callbacks.request_write_characteristic_cb, conn_id, trans_id, + addr.value(), attr_handle, offset, need_response, is_prepare, + base::Owned(buf), value.size())); +} + } // namespace gatt } // namespace bluetooth diff --git a/system/rust/src/gatt/ffi/gatt_shim.h b/system/rust/src/gatt/ffi/gatt_shim.h index c680058406..db5e8a27e9 100644 --- a/system/rust/src/gatt/ffi/gatt_shim.h +++ b/system/rust/src/gatt/ffi/gatt_shim.h @@ -34,6 +34,11 @@ class GattServerCallbacks { uint16_t attr_handle, uint32_t offset, bool is_long) const; + void OnServerWriteCharacteristic(uint16_t conn_id, uint32_t trans_id, + uint16_t attr_handle, uint32_t offset, + bool need_response, bool is_prepare, + ::rust::Slice<const uint8_t> value) const; + private: const btgatt_server_callbacks_t& callbacks; }; diff --git a/system/rust/src/gatt/mocks/mock_callbacks.rs b/system/rust/src/gatt/mocks/mock_callbacks.rs index 9df2d35b4e..bea05dac05 100644 --- a/system/rust/src/gatt/mocks/mock_callbacks.rs +++ b/system/rust/src/gatt/mocks/mock_callbacks.rs @@ -1,8 +1,11 @@ //! Mocked implementation of GattCallbacks for use in test -use crate::gatt::{ - ids::{AttHandle, ConnectionId, TransactionId}, - GattCallbacks, +use crate::{ + gatt::{ + ids::{AttHandle, ConnectionId, TransactionId}, + GattCallbacks, + }, + packets::{AttAttributeDataView, OwnedAttAttributeDataView, Packet}, }; use tokio::sync::mpsc::{self, unbounded_channel, UnboundedReceiver}; @@ -22,6 +25,16 @@ impl MockCallbacks { pub enum MockCallbackEvents { /// GattCallbacks#on_server_read_characteristic invoked OnServerReadCharacteristic(ConnectionId, TransactionId, AttHandle, u32, bool), + /// GattCallbacks#on_server_write_characteristic invoked + OnServerWriteCharacteristic( + ConnectionId, + TransactionId, + AttHandle, + u32, + bool, + bool, + OwnedAttAttributeDataView, + ), } impl GattCallbacks for MockCallbacks { @@ -39,4 +52,27 @@ impl GattCallbacks for MockCallbacks { )) .unwrap(); } + + fn on_server_write_characteristic( + &self, + conn_id: ConnectionId, + trans_id: TransactionId, + handle: AttHandle, + offset: u32, + need_response: bool, + is_prepare: bool, + value: AttAttributeDataView, + ) { + self.0 + .send(MockCallbackEvents::OnServerWriteCharacteristic( + conn_id, + trans_id, + handle, + offset, + need_response, + is_prepare, + value.to_owned_packet(), + )) + .unwrap(); + } } diff --git a/system/rust/src/gatt/mocks/mock_datastore.rs b/system/rust/src/gatt/mocks/mock_datastore.rs index d1e25a0909..7e7eac2465 100644 --- a/system/rust/src/gatt/mocks/mock_datastore.rs +++ b/system/rust/src/gatt/mocks/mock_datastore.rs @@ -5,7 +5,10 @@ use crate::{ callbacks::GattDatastore, ids::{AttHandle, ConnectionId}, }, - packets::{AttAttributeDataChild, AttErrorCode, OwnedAttAttributeDataView}, + packets::{ + AttAttributeDataChild, AttAttributeDataView, AttErrorCode, OwnedAttAttributeDataView, + Packet, + }, }; use async_trait::async_trait; use log::info; @@ -70,4 +73,22 @@ impl GattDatastore for MockDatastore { info!("sending {resp:?} down from upper tester"); resp } + + async fn write_characteristic( + &self, + conn_id: ConnectionId, + handle: AttHandle, + data: AttAttributeDataView<'_>, + ) -> Result<(), AttErrorCode> { + let (tx, rx) = oneshot::channel(); + self.0 + .send(MockDatastoreEvents::WriteCharacteristic( + conn_id, + handle, + data.to_owned_packet(), + tx, + )) + .unwrap(); + rx.await.unwrap() + } } diff --git a/system/rust/src/gatt/server/att_database.rs b/system/rust/src/gatt/server/att_database.rs index bc40552b74..d60d8cc035 100644 --- a/system/rust/src/gatt/server/att_database.rs +++ b/system/rust/src/gatt/server/att_database.rs @@ -3,7 +3,9 @@ use async_trait::async_trait; use crate::{ core::uuid::Uuid, gatt::ids::AttHandle, - packets::{AttAttributeDataChild, AttErrorCode, AttHandleBuilder, AttHandleView}, + packets::{ + AttAttributeDataChild, AttAttributeDataView, AttErrorCode, AttHandleBuilder, AttHandleView, + }, }; // UUIDs from Bluetooth Assigned Numbers Sec 3.6 @@ -54,6 +56,13 @@ pub trait AttDatabase { handle: AttHandle, ) -> Result<AttAttributeDataChild, AttErrorCode>; + /// Write to an attribute by handle + async fn write_attribute( + &self, + handle: AttHandle, + data: AttAttributeDataView<'_>, + ) -> Result<(), AttErrorCode>; + /// List all the attributes in this database. /// /// Expected to return them in sorted order. @@ -94,6 +103,14 @@ impl AttDatabase for SnapshottedAttDatabase<'_> { self.backing.read_attribute(handle).await } + async fn write_attribute( + &self, + handle: AttHandle, + data: AttAttributeDataView<'_>, + ) -> Result<(), AttErrorCode> { + self.backing.write_attribute(handle, data).await + } + fn list_attributes(&self) -> Vec<AttAttribute> { self.attributes.clone() } diff --git a/system/rust/src/gatt/server/att_server_bearer.rs b/system/rust/src/gatt/server/att_server_bearer.rs index 29f53d28c3..831b7c8e06 100644 --- a/system/rust/src/gatt/server/att_server_bearer.rs +++ b/system/rust/src/gatt/server/att_server_bearer.rs @@ -103,11 +103,7 @@ impl<T: AttDatabase + 'static> AttServerBearer<T> { #[cfg(test)] mod test { - use tokio::{ - runtime::Runtime, - sync::mpsc::{error::TryRecvError, unbounded_channel, UnboundedReceiver}, - task::LocalSet, - }; + use tokio::sync::mpsc::{error::TryRecvError, unbounded_channel, UnboundedReceiver}; use super::*; @@ -128,7 +124,10 @@ mod test { packets::{ AttAttributeDataChild, AttOpcode, AttReadRequestBuilder, AttReadResponseBuilder, }, - utils::packet::{build_att_data, build_att_view_or_crash}, + utils::{ + packet::{build_att_data, build_att_view_or_crash}, + task::block_on_locally, + }, }; const VALID_HANDLE: AttHandle = AttHandle(3); @@ -156,7 +155,7 @@ mod test { #[test] fn test_single_transaction() { - LocalSet::new().block_on(&Runtime::new().unwrap(), async { + block_on_locally(async { let (conn, mut rx) = open_connection(); conn.handle_packet( build_att_view_or_crash(AttReadRequestBuilder { @@ -171,7 +170,7 @@ mod test { #[test] fn test_sequential_transactions() { - LocalSet::new().block_on(&Runtime::new().unwrap(), async { + block_on_locally(async { let (conn, mut rx) = open_connection(); conn.handle_packet( build_att_view_or_crash(AttReadRequestBuilder { @@ -229,7 +228,7 @@ mod test { // act: send two read requests before replying to either read // first request - LocalSet::new().block_on(&Runtime::new().unwrap(), async { + block_on_locally(async { let req1 = build_att_view_or_crash(AttReadRequestBuilder { attribute_handle: VALID_HANDLE.into(), }); diff --git a/system/rust/src/gatt/server/gatt_database.rs b/system/rust/src/gatt/server/gatt_database.rs index b110efb3c4..f52496bd63 100644 --- a/system/rust/src/gatt/server/gatt_database.rs +++ b/system/rust/src/gatt/server/gatt_database.rs @@ -14,8 +14,9 @@ use crate::{ ids::{AttHandle, ConnectionId}, }, packets::{ - AttAttributeDataChild, AttCharacteristicPropertiesBuilder, AttErrorCode, - GattCharacteristicDeclarationValueBuilder, GattServiceDeclarationValueBuilder, UuidBuilder, + AttAttributeDataChild, AttAttributeDataView, AttCharacteristicPropertiesBuilder, + AttErrorCode, GattCharacteristicDeclarationValueBuilder, + GattServiceDeclarationValueBuilder, UuidBuilder, }, }; @@ -258,6 +259,25 @@ where self.gatt_db.datastore.read_characteristic(self.conn_id, handle).await } + async fn write_attribute( + &self, + handle: AttHandle, + data: AttAttributeDataView<'_>, + ) -> Result<(), AttErrorCode> { + { + // block needed to drop the RefCell before the async point + let services = self.gatt_db.schema.borrow(); + let Some(attr) = services.attributes.get(&handle) else { + return Err(AttErrorCode::INVALID_HANDLE); + }; + if !attr.attribute.permissions.writable { + return Err(AttErrorCode::WRITE_NOT_PERMITTED); + } + } + + self.gatt_db.datastore.write_characteristic(self.conn_id, handle, data).await + } + fn list_attributes(&self) -> Vec<AttAttribute> { self.gatt_db.schema.borrow().attributes.values().map(|attr| attr.attribute).collect() } @@ -265,9 +285,16 @@ where #[cfg(test)] mod test { - use tokio::join; + use tokio::{join, task::spawn_local}; - use crate::gatt::mocks::mock_datastore::{MockDatastore, MockDatastoreEvents}; + use crate::{ + gatt::mocks::mock_datastore::{MockDatastore, MockDatastoreEvents}, + packets::Packet, + utils::{ + packet::{build_att_data, build_view_or_crash}, + task::block_on_locally, + }, + }; use super::*; @@ -578,4 +605,118 @@ mod test { let read_result = tokio_test::block_on(att_db.read_attribute(SERVICE_HANDLE)); assert!(read_result.is_err()); } + + #[test] + fn test_write_single_characteristic_callback_invoked() { + // arrange: create a database with a single characteristic + let (gatt_datastore, mut data_evts) = MockDatastore::new(); + let gatt_db = Rc::new(GattDatabase::new(gatt_datastore.into())); + gatt_db + .add_service_with_handles(GattServiceWithHandle { + handle: SERVICE_HANDLE, + type_: SERVICE_TYPE, + characteristics: vec![GattCharacteristicWithHandle { + handle: CHARACTERISTIC_VALUE_HANDLE, + type_: CHARACTERISTIC_TYPE, + permissions: AttPermissions { readable: false, writable: true }, + }], + }) + .unwrap(); + let att_db = gatt_db.get_att_database(CONN_ID); + let data = + build_view_or_crash(build_att_data(AttAttributeDataChild::RawData(Box::new([1, 2])))); + + // act: write to the database + let recv_data = block_on_locally(async { + // start write task + let cloned_data = data.view().to_owned_packet(); + spawn_local(async move { + att_db + .write_attribute(CHARACTERISTIC_VALUE_HANDLE, cloned_data.view()) + .await + .unwrap(); + }); + + let MockDatastoreEvents::WriteCharacteristic( + CONN_ID, + CHARACTERISTIC_VALUE_HANDLE, + recv_data, + _, + ) = data_evts.recv().await.unwrap() else { + unreachable!(); + }; + recv_data + }); + + // assert: the received value matches what we supplied + assert_eq!( + recv_data.view().get_raw_payload().collect::<Vec<_>>(), + data.view().get_raw_payload().collect::<Vec<_>>() + ); + } + + #[test] + fn test_write_single_characteristic_recv_response() { + // arrange: create a database with a single characteristic + let (gatt_datastore, mut data_evts) = MockDatastore::new(); + let gatt_db = Rc::new(GattDatabase::new(gatt_datastore.into())); + gatt_db + .add_service_with_handles(GattServiceWithHandle { + handle: SERVICE_HANDLE, + type_: SERVICE_TYPE, + characteristics: vec![GattCharacteristicWithHandle { + handle: CHARACTERISTIC_VALUE_HANDLE, + type_: CHARACTERISTIC_TYPE, + permissions: AttPermissions { readable: false, writable: true }, + }], + }) + .unwrap(); + let att_db = gatt_db.get_att_database(CONN_ID); + let data = + build_view_or_crash(build_att_data(AttAttributeDataChild::RawData(Box::new([1, 2])))); + + // act: write to the database + let res = tokio_test::block_on(async { + join!( + async { + let MockDatastoreEvents::WriteCharacteristic(_,_,_,reply) = data_evts.recv().await.unwrap() else { + unreachable!(); + }; + reply.send(Err(AttErrorCode::UNLIKELY_ERROR)).unwrap(); + }, + att_db.write_attribute(CHARACTERISTIC_VALUE_HANDLE, data.view()) + ) + .1 + }); + + // assert: the supplied value matches what the att datastore returned + assert_eq!(res, Err(AttErrorCode::UNLIKELY_ERROR)); + } + + #[test] + fn test_unwriteable_characteristic() { + let (gatt_datastore, _) = MockDatastore::new(); + let gatt_db = Rc::new(GattDatabase::new(gatt_datastore.into())); + gatt_db + .add_service_with_handles(GattServiceWithHandle { + handle: SERVICE_HANDLE, + type_: SERVICE_TYPE, + characteristics: vec![GattCharacteristicWithHandle { + handle: CHARACTERISTIC_VALUE_HANDLE, + type_: CHARACTERISTIC_TYPE, + permissions: AttPermissions::READONLY, + }], + }) + .unwrap(); + let data = + build_view_or_crash(build_att_data(AttAttributeDataChild::RawData(Box::new([1, 2])))); + + let characteristic_value = tokio_test::block_on( + gatt_db + .get_att_database(CONN_ID) + .write_attribute(CHARACTERISTIC_VALUE_HANDLE, data.view()), + ); + + assert_eq!(characteristic_value, Err(AttErrorCode::WRITE_NOT_PERMITTED)); + } } diff --git a/system/rust/src/gatt/server/test/test_att_db.rs b/system/rust/src/gatt/server/test/test_att_db.rs index c2581ddf23..849a8254dd 100644 --- a/system/rust/src/gatt/server/test/test_att_db.rs +++ b/system/rust/src/gatt/server/test/test_att_db.rs @@ -6,24 +6,23 @@ use crate::{ gatt_database::AttPermissions, }, }, - packets::{AttAttributeDataChild, AttErrorCode}, + packets::{AttAttributeDataChild, AttAttributeDataView, AttErrorCode}, }; use async_trait::async_trait; use log::info; -use std::collections::BTreeMap; +use std::{cell::Cell, collections::BTreeMap}; pub struct TestAttDatabase { - attributes: BTreeMap<AttHandle, (AttAttribute, Vec<u8>)>, + attributes: BTreeMap<AttHandle, (AttAttribute, Cell<Vec<u8>>)>, } impl TestAttDatabase { - #[cfg(test)] pub fn new(attributes: Vec<(AttAttribute, Vec<u8>)>) -> Self { Self { attributes: attributes .into_iter() - .map(|(att, data)| (att.handle, (att, data))) + .map(|(att, data)| (att.handle, (att, Cell::new(data)))) .collect(), } } @@ -40,7 +39,27 @@ impl AttDatabase for TestAttDatabase { Some((AttAttribute { permissions: AttPermissions { readable: false, .. }, .. }, _)) => { Err(AttErrorCode::READ_NOT_PERMITTED) } - Some((_, data)) => Ok(AttAttributeDataChild::RawData(data.clone().into_boxed_slice())), + Some((_, data)) => { + let contents = data.take(); + data.set(contents.clone()); + Ok(AttAttributeDataChild::RawData(contents.into_boxed_slice())) + } + None => Err(AttErrorCode::INVALID_HANDLE), + } + } + async fn write_attribute( + &self, + handle: AttHandle, + data: AttAttributeDataView<'_>, + ) -> Result<(), AttErrorCode> { + match self.attributes.get(&handle) { + Some((AttAttribute { permissions: AttPermissions { writable: false, .. }, .. }, _)) => { + Err(AttErrorCode::WRITE_NOT_PERMITTED) + } + Some((_, data_cell)) => { + data_cell.replace(data.get_raw_payload().collect()); + Ok(()) + } None => Err(AttErrorCode::INVALID_HANDLE), } } diff --git a/system/rust/src/gatt/server/transaction_handler.rs b/system/rust/src/gatt/server/transaction_handler.rs index ccb51a9946..b56e70c63f 100644 --- a/system/rust/src/gatt/server/transaction_handler.rs +++ b/system/rust/src/gatt/server/transaction_handler.rs @@ -5,7 +5,8 @@ use crate::{ packets::{ AttChild, AttErrorCode, AttErrorResponseBuilder, AttFindByTypeValueRequestView, AttFindInformationRequestView, AttOpcode, AttReadByGroupTypeRequestView, - AttReadByTypeRequestView, AttReadRequestView, AttView, Packet, ParseError, + AttReadByTypeRequestView, AttReadRequestView, AttView, AttWriteRequestView, Packet, + ParseError, }, }; @@ -16,6 +17,7 @@ use super::{ find_information_request::handle_find_information_request, read_by_group_type_request::handle_read_by_group_type_request, read_by_type_request::handle_read_by_type_request, read_request::handle_read_request, + write_request::handle_write_request, }, }; @@ -85,6 +87,9 @@ impl<Db: AttDatabase> AttTransactionHandler<Db> { &snapshotted_db, ) .await), + AttOpcode::WRITE_REQUEST => { + Ok(handle_write_request(AttWriteRequestView::try_parse(packet)?, &self.db).await) + } _ => { warn!("Dropping unsupported opcode {:?}", packet.get_opcode()); Err(ParseError::InvalidEnumValue) diff --git a/system/rust/src/gatt/server/transactions.rs b/system/rust/src/gatt/server/transactions.rs index 92198c371e..e476bdc75b 100644 --- a/system/rust/src/gatt/server/transactions.rs +++ b/system/rust/src/gatt/server/transactions.rs @@ -4,3 +4,4 @@ mod helpers; pub mod read_by_group_type_request; pub mod read_by_type_request; pub mod read_request; +pub mod write_request; diff --git a/system/rust/src/gatt/server/transactions/write_request.rs b/system/rust/src/gatt/server/transactions/write_request.rs new file mode 100644 index 0000000000..163101fcd5 --- /dev/null +++ b/system/rust/src/gatt/server/transactions/write_request.rs @@ -0,0 +1,100 @@ +use crate::{ + gatt::server::att_database::AttDatabase, + packets::{ + AttChild, AttErrorResponseBuilder, AttOpcode, AttWriteRequestView, AttWriteResponseBuilder, + }, +}; + +pub async fn handle_write_request<T: AttDatabase>( + request: AttWriteRequestView<'_>, + db: &T, +) -> AttChild { + let handle = request.get_handle().into(); + match db.write_attribute(handle, request.get_value()).await { + Ok(()) => AttWriteResponseBuilder {}.into(), + Err(error_code) => AttErrorResponseBuilder { + opcode_in_error: AttOpcode::WRITE_REQUEST, + handle_in_error: handle.into(), + error_code, + } + .into(), + } +} + +#[cfg(test)] +mod test { + use super::*; + + use tokio_test::block_on; + + use crate::{ + core::uuid::Uuid, + gatt::{ + ids::AttHandle, + server::{ + att_database::{AttAttribute, AttDatabase}, + gatt_database::AttPermissions, + test::test_att_db::TestAttDatabase, + }, + }, + packets::{ + AttAttributeDataChild, AttChild, AttErrorCode, AttErrorResponseBuilder, + AttWriteRequestBuilder, AttWriteResponseBuilder, + }, + utils::packet::{build_att_data, build_view_or_crash}, + }; + + #[test] + fn test_successful_write() { + // arrange: db with one writable attribute + let db = TestAttDatabase::new(vec![( + AttAttribute { + handle: AttHandle(1), + type_: Uuid::new(0x1234), + permissions: AttPermissions { readable: true, writable: true }, + }, + vec![], + )]); + let data = AttAttributeDataChild::RawData([1, 2].into()); + + // act: write to the attribute + let att_view = build_view_or_crash(AttWriteRequestBuilder { + handle: AttHandle(1).into(), + value: build_att_data(data.clone()), + }); + let resp = block_on(handle_write_request(att_view.view(), &db)); + + // assert: that the write succeeded + assert_eq!(resp, AttChild::from(AttWriteResponseBuilder {})); + assert_eq!(block_on(db.read_attribute(AttHandle(1))).unwrap(), data); + } + + #[test] + fn test_failed_write() { + // arrange: db with no writable attributes + let db = TestAttDatabase::new(vec![( + AttAttribute { + handle: AttHandle(1), + type_: Uuid::new(0x1234), + permissions: AttPermissions { readable: true, writable: false }, + }, + vec![], + )]); + // act: write to the attribute + let att_view = build_view_or_crash(AttWriteRequestBuilder { + handle: AttHandle(1).into(), + value: build_att_data(AttAttributeDataChild::RawData([1, 2].into())), + }); + let resp = block_on(handle_write_request(att_view.view(), &db)); + + // assert: that the write failed + assert_eq!( + resp, + AttChild::from(AttErrorResponseBuilder { + opcode_in_error: AttOpcode::WRITE_REQUEST, + handle_in_error: AttHandle(1).into(), + error_code: AttErrorCode::WRITE_NOT_PERMITTED + }) + ); + } +} diff --git a/system/rust/src/utils.rs b/system/rust/src/utils.rs index bde0f821cd..bcc35dd44e 100644 --- a/system/rust/src/utils.rs +++ b/system/rust/src/utils.rs @@ -2,3 +2,4 @@ pub mod owned_handle; pub mod packet; +pub mod task; diff --git a/system/rust/src/utils/task.rs b/system/rust/src/utils/task.rs new file mode 100644 index 0000000000..5e240417ae --- /dev/null +++ b/system/rust/src/utils/task.rs @@ -0,0 +1,10 @@ +//! This module provides utilities relating to async tasks + +use std::future::Future; + +use tokio::{runtime::Builder, task::LocalSet}; + +/// Run the supplied future on a single-threaded runtime +pub fn block_on_locally<T>(f: impl Future<Output = T>) -> T { + LocalSet::new().block_on(&Builder::new_current_thread().build().unwrap(), f) +} diff --git a/system/rust/tests/gatt_callbacks_test.rs b/system/rust/tests/gatt_callbacks_test.rs index 18cabb2109..be4b556a19 100644 --- a/system/rust/tests/gatt_callbacks_test.rs +++ b/system/rust/tests/gatt_callbacks_test.rs @@ -8,7 +8,8 @@ use bluetooth_core::{ ids::{AttHandle, ConnectionId, ServerId, TransactionId, TransportIndex}, mocks::mock_callbacks::{MockCallbackEvents, MockCallbacks}, }, - packets::AttAttributeDataChild, + packets::{AttAttributeDataChild, AttErrorCode, Packet}, + utils::packet::{build_att_data, build_view_or_crash}, }; use tokio::{sync::mpsc::UnboundedReceiver, task::spawn_local}; use utils::start_test; @@ -32,9 +33,10 @@ fn initialize_manager_with_connection( } async fn pull_trans_id(events_rx: &mut UnboundedReceiver<MockCallbackEvents>) -> TransactionId { - let MockCallbackEvents::OnServerReadCharacteristic(_, trans_id, _, _, _) = - events_rx.recv().await.unwrap(); - trans_id + match events_rx.recv().await.unwrap() { + MockCallbackEvents::OnServerReadCharacteristic(_, trans_id, _, _, _) => trans_id, + MockCallbackEvents::OnServerWriteCharacteristic(_, trans_id, _, _, _, _, _) => trans_id, + } } #[test] @@ -196,3 +198,54 @@ fn test_invalid_trans_id() { assert_eq!(err, CallbackResponseError::NonExistentTransaction(invalid_trans_id)); }); } + +#[test] +fn test_write_characteristic_callback() { + start_test(async { + // arrange + let (callback_manager, mut callbacks_rx) = initialize_manager_with_connection(); + + // act: start write operation + let data = + build_view_or_crash(build_att_data(AttAttributeDataChild::RawData([1, 2].into()))); + let cloned_data = data.view().to_owned_packet(); + spawn_local(async move { + callback_manager.write_characteristic(CONN_ID, HANDLE_1, cloned_data.view()).await + }); + + // assert: verify the write callback is received + let MockCallbackEvents::OnServerWriteCharacteristic( + CONN_ID, _, HANDLE_1, 0, /* needs_response = */ true, false, recv_data + ) = callbacks_rx.recv().await.unwrap() else { + unreachable!() + }; + assert_eq!( + recv_data.view().get_raw_payload().collect::<Vec<_>>(), + data.view().get_raw_payload().collect::<Vec<_>>() + ); + }); +} + +#[test] +fn test_write_characteristic_response() { + start_test(async { + // arrange + let (callback_manager, mut callbacks_rx) = initialize_manager_with_connection(); + + // act: start write operation + let data = + build_view_or_crash(build_att_data(AttAttributeDataChild::RawData([1, 2].into()))); + let cloned_manager = callback_manager.clone(); + let pending_write = spawn_local(async move { + cloned_manager.write_characteristic(CONN_ID, HANDLE_1, data.view()).await + }); + // provide a response with some error code + let trans_id = pull_trans_id(&mut callbacks_rx).await; + callback_manager + .send_response(CONN_ID, trans_id, Err(AttErrorCode::WRITE_NOT_PERMITTED)) + .unwrap(); + + // assert: that the error code was received + assert_eq!(pending_write.await.unwrap(), Err(AttErrorCode::WRITE_NOT_PERMITTED)); + }); +} diff --git a/system/rust/tests/gatt_server_test.rs b/system/rust/tests/gatt_server_test.rs index 011c2000cd..4b48ef2b46 100644 --- a/system/rust/tests/gatt_server_test.rs +++ b/system/rust/tests/gatt_server_test.rs @@ -16,7 +16,8 @@ use bluetooth_core::{ }, packets::{ AttAttributeDataChild, AttBuilder, AttErrorCode, AttErrorResponseBuilder, AttOpcode, - AttReadRequestBuilder, AttReadResponseBuilder, GattServiceDeclarationValueBuilder, + AttReadRequestBuilder, AttReadResponseBuilder, AttWriteRequestBuilder, + AttWriteResponseBuilder, GattServiceDeclarationValueBuilder, Serializable, }, utils::packet::{build_att_data, build_att_view_or_crash}, }; @@ -56,7 +57,7 @@ fn create_server_and_open_connection(gatt: &mut GattModule) { characteristics: vec![GattCharacteristicWithHandle { handle: HANDLE_2, type_: UUID_2, - permissions: AttPermissions { readable: true, writable: false }, + permissions: AttPermissions { readable: true, writable: true }, }], }, ) @@ -222,3 +223,51 @@ fn test_characteristic_read() { ); }) } + +#[test] +fn test_characteristic_write() { + start_test(async move { + // arrange + let (mut gatt, mut data_rx, mut transport_rx) = start_gatt_module(); + + let data = AttAttributeDataChild::RawData([5, 6, 7, 8].into()); + + create_server_and_open_connection(&mut gatt); + data_rx.recv().await.unwrap(); + + // act + gatt.handle_packet( + CONN_ID, + build_att_view_or_crash(AttWriteRequestBuilder { + handle: HANDLE_2.into(), + value: build_att_data(data.clone()), + }) + .view(), + ) + .unwrap(); + let (tx, written_data) = + if let MockDatastoreEvents::WriteCharacteristic(CONN_ID, HANDLE_2, written_data, tx) = + data_rx.recv().await.unwrap() + { + (tx, written_data) + } else { + unreachable!() + }; + tx.send(Ok(())).unwrap(); + let (tcb_idx, resp) = transport_rx.recv().await.unwrap(); + + // assert + assert_eq!(tcb_idx, TCB_IDX); + assert_eq!( + resp, + AttBuilder { + opcode: AttOpcode::WRITE_RESPONSE, + _child_: AttWriteResponseBuilder {}.into() + } + ); + assert_eq!( + data.to_vec().unwrap(), + written_data.view().get_raw_payload().collect::<Vec<_>>() + ) + }) +} |