diff options
author | 2023-03-08 03:15:03 +0000 | |
---|---|---|
committer | 2023-03-09 01:26:28 +0000 | |
commit | 189169e520cb0e480fa2ba41de5ddfc19c7c35e7 (patch) | |
tree | bab192ba3962276b73a2791e58077d6f29f6860f | |
parent | e8b68ea94aeb440e63c255bab747c3dcc6565276 (diff) |
[Private GATT] Add support for MTU Exchange
Snoop MTU_REQ/RSP packets from legacy stack, and use them to track the
MTU used in the isolated server.
Bug: 255880936
Test: unit
Change-Id: Ifcaa35be47abdbf714b592318184701645b55800
-rw-r--r-- | system/rust/src/gatt.rs | 1 | ||||
-rw-r--r-- | system/rust/src/gatt/arbiter.rs | 58 | ||||
-rw-r--r-- | system/rust/src/gatt/ffi.rs | 3 | ||||
-rw-r--r-- | system/rust/src/gatt/mtu.rs | 251 | ||||
-rw-r--r-- | system/rust/src/gatt/server/att_server_bearer.rs | 132 | ||||
-rw-r--r-- | system/rust/src/gatt/server/indication_handler.rs | 58 | ||||
-rw-r--r-- | system/rust/src/packets.pdl | 8 | ||||
-rw-r--r-- | system/rust/src/utils.rs | 2 | ||||
-rw-r--r-- | system/rust/src/utils/packet.rs | 2 | ||||
-rw-r--r-- | system/rust/src/utils/task.rs | 52 | ||||
-rw-r--r-- | system/rust/tests/utils/mod.rs | 1 | ||||
-rw-r--r-- | system/stack/Android.bp | 1 | ||||
-rw-r--r-- | system/stack/arbiter/acl_arbiter.cc | 38 | ||||
-rw-r--r-- | system/stack/arbiter/acl_arbiter.h | 9 | ||||
-rw-r--r-- | system/stack/gatt/gatt_api.cc | 3 | ||||
-rw-r--r-- | system/stack/gatt/gatt_cl.cc | 4 | ||||
-rw-r--r-- | system/stack/gatt/gatt_sr.cc | 4 | ||||
-rw-r--r-- | system/test/Android.bp | 7 | ||||
-rw-r--r-- | system/test/mock/mock_stack_arbiter_acl_arbiter.cc | 52 |
19 files changed, 657 insertions, 29 deletions
diff --git a/system/rust/src/gatt.rs b/system/rust/src/gatt.rs index dd61f43683..c3e22b7004 100644 --- a/system/rust/src/gatt.rs +++ b/system/rust/src/gatt.rs @@ -7,6 +7,7 @@ pub mod channel; pub mod ffi; pub mod ids; pub mod mocks; +mod mtu; pub mod opcode_types; pub mod server; diff --git a/system/rust/src/gatt/arbiter.rs b/system/rust/src/gatt/arbiter.rs index d636f1fea7..cae6a3a2fe 100644 --- a/system/rust/src/gatt/arbiter.rs +++ b/system/rust/src/gatt/arbiter.rs @@ -7,13 +7,13 @@ use log::{error, info, trace}; use crate::{ do_in_rust_thread, - gatt::server::att_server_bearer::AttServerBearer, - packets::{OwnedAttView, OwnedPacket}, + packets::{AttOpcode, OwnedAttView, OwnedPacket}, }; use super::{ ffi::{InterceptAction, StoreCallbacksFromRust}, ids::{AdvertiserId, ConnectionId, ServerId, TransportIndex}, + mtu::MtuEvent, opcode_types::{classify_opcode, OperationType}, }; @@ -32,7 +32,14 @@ pub struct Arbiter { pub fn initialize_arbiter() { *ARBITER.lock().unwrap() = Some(Arbiter::new()); - StoreCallbacksFromRust(on_le_connect, on_le_disconnect, intercept_packet); + StoreCallbacksFromRust( + on_le_connect, + on_le_disconnect, + intercept_packet, + |tcb_idx| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::OutgoingRequest), + |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingResponse(mtu)), + |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingRequest(mtu)), + ); } /// Acquire the mutex holding the Arbiter and provide a mutable reference to the @@ -93,6 +100,12 @@ impl Arbiter { let att = OwnedAttView::try_parse(packet).ok()?; + if att.view().get_opcode() == AttOpcode::EXCHANGE_MTU_REQUEST { + // special case: this server opcode is handled by legacy stack, and we snoop + // on its handling, since the MTU is shared between the client + server + return None; + } + match classify_opcode(att.view().get_opcode()) { OperationType::Command | OperationType::Request | OperationType::Confirmation => { Some((att, conn_id)) @@ -127,6 +140,11 @@ impl Arbiter { info!("processing disconnection on transport {tcb_idx:?}"); self.transport_to_owned_connection.remove(&tcb_idx) } + + /// Look up the conn_id for a given tcb_idx, if present + pub fn get_conn_id(&self, tcb_idx: TransportIndex) -> Option<ConnectionId> { + self.transport_to_owned_connection.get(&tcb_idx).copied() + } } fn on_le_connect(tcb_idx: u8, advertiser: u8) { @@ -168,13 +186,30 @@ fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction { } } +fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) { + if let Some(conn_id) = with_arbiter(|arbiter| arbiter.get_conn_id(tcb_idx)) { + do_in_rust_thread(move |modules| { + let Some(bearer) = modules.gatt_module.get_bearer(conn_id) else { + error!("Bearer for {conn_id:?} not found"); + return; + }; + if let Err(err) = bearer.handle_mtu_event(event) { + error!("{err:?}") + } + }); + } +} + #[cfg(test)] mod test { use super::*; use crate::{ gatt::ids::AttHandle, - packets::{AttBuilder, AttOpcode, AttReadRequestBuilder, Serializable}, + packets::{ + AttBuilder, AttExchangeMtuRequestBuilder, AttOpcode, AttReadRequestBuilder, + Serializable, + }, }; const TCB_IDX: TransportIndex = TransportIndex(1); @@ -330,6 +365,21 @@ mod test { } #[test] + fn test_mtu_bypass() { + let mut arbiter = Arbiter::new(); + arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID); + arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID); + let packet = AttBuilder { + opcode: AttOpcode::EXCHANGE_MTU_REQUEST, + _child_: AttExchangeMtuRequestBuilder { mtu: 64 }.into(), + }; + + let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet.to_vec().unwrap().into()); + + assert!(out.is_none()); + } + + #[test] fn test_packet_bypass_when_not_isolated() { let mut arbiter = Arbiter::new(); arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID); diff --git a/system/rust/src/gatt/ffi.rs b/system/rust/src/gatt/ffi.rs index 32dbb6c8c2..cf8f9583a2 100644 --- a/system/rust/src/gatt/ffi.rs +++ b/system/rust/src/gatt/ffi.rs @@ -149,6 +149,9 @@ mod inner { on_le_connect: fn(tcb_idx: u8, advertiser: u8), on_le_disconnect: fn(tcb_idx: u8), intercept_packet: fn(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction, + on_outgoing_mtu_req: fn(tcb_idx: u8), + on_incoming_mtu_resp: fn(tcb_idx: u8, mtu: usize), + on_incoming_mtu_req: fn(tcb_idx: u8, mtu: usize), ); /// Send an outgoing packet on the specified tcb_idx diff --git a/system/rust/src/gatt/mtu.rs b/system/rust/src/gatt/mtu.rs new file mode 100644 index 0000000000..838cfcda8b --- /dev/null +++ b/system/rust/src/gatt/mtu.rs @@ -0,0 +1,251 @@ +//! The MTU on an ATT bearer is determined either by L2CAP (if EATT) or by the +//! ATT_EXCHANGE_MTU procedure (if on an unenhanced bearer). +//! +//! In the latter case, the MTU may be either (1) unset, (2) pending, or (3) +//! set. If the MTU is pending, ATT notifications/indications may not be sent. +//! Refer to Core Spec 5.3 Vol 3F 3.4.2 MTU exchange for full details. + +use std::{cell::Cell, future::Future}; + +use anyhow::{bail, Result}; +use log::info; +use tokio::sync::OwnedMutexGuard; + +use crate::core::shared_mutex::SharedMutex; + +/// An MTU event that we have snooped +pub enum MtuEvent { + /// We have sent an MTU_REQ + OutgoingRequest, + /// We have received an MTU_RESP + IncomingResponse(usize), + /// We have received an MTU_REQ (and will immediately reply) + IncomingRequest(usize), +} + +/// The state of MTU negotiation on an unenhanced ATT bearer +pub struct AttMtu { + /// The MTU we have committed to (i.e. sent a REQ and got a RESP, or + /// vice-versa) + previous_mtu: Cell<usize>, + /// The MTU we have committed or are about to commit to (if a REQ is + /// pending) + stable_mtu: SharedMutex<usize>, + /// Lock guard held if we are currrently performing MTU negotiation + pending_exchange: Cell<Option<OwnedMutexGuard<usize>>>, +} + +// NOTE: this is only true for ATT, not EATT +const DEFAULT_ATT_MTU: usize = 23; + +impl AttMtu { + /// Constructor + pub fn new() -> Self { + Self { + previous_mtu: Cell::new(DEFAULT_ATT_MTU), + stable_mtu: SharedMutex::new(DEFAULT_ATT_MTU), + pending_exchange: Cell::new(None), + } + } + + /// Get the most recently negotiated MTU, or the default (if an MTU_REQ is + /// outstanding and we get an ATT_REQ) + pub fn snapshot_or_default(&self) -> usize { + self.stable_mtu.try_lock().as_deref().cloned().unwrap_or_else(|_| self.previous_mtu.get()) + } + + /// Get the most recently negotiated MTU, or block if negotiation is ongoing + /// (i.e. if an MTU_REQ is outstanding) + pub fn snapshot(&self) -> impl Future<Output = Option<usize>> { + let pending_snapshot = self.stable_mtu.lock(); + async move { pending_snapshot.await.as_deref().cloned() } + } + + /// Handle an MtuEvent and update the stored MTU + pub fn handle_event(&self, event: MtuEvent) -> Result<()> { + match event { + MtuEvent::OutgoingRequest => self.on_outgoing_request(), + MtuEvent::IncomingResponse(mtu) => self.on_incoming_response(mtu), + MtuEvent::IncomingRequest(mtu) => { + self.on_incoming_request(mtu); + Ok(()) + } + } + } + + fn on_outgoing_request(&self) -> Result<()> { + let Ok(pending_mtu) = self.stable_mtu.try_lock() else { + bail!("Sent ATT_EXCHANGE_MTU_REQ while an existing MTU exchange is taking place"); + }; + info!("Sending MTU_REQ, pausing indications/notifications"); + self.pending_exchange.replace(Some(pending_mtu)); + Ok(()) + } + + fn on_incoming_response(&self, mtu: usize) -> Result<()> { + let Some(mut pending_exchange) = self.pending_exchange.take() else { + bail!("Got ATT_EXCHANGE_MTU_RESP when transaction not taking place"); + }; + info!("Got an MTU_RESP of {mtu}"); + *pending_exchange = mtu; + // note: since MTU_REQ can be sent at most once, this is a no-op, as the + // stable_mtu will never again be blocked we do it anyway for clarity + self.previous_mtu.set(mtu); + Ok(()) + } + + fn on_incoming_request(&self, mtu: usize) { + self.previous_mtu.set(mtu); + if let Ok(mut stable_mtu) = self.stable_mtu.try_lock() { + info!("Accepted an MTU_REQ of {mtu:?}"); + *stable_mtu = mtu; + } else { + info!("Accepted an MTU_REQ while our own MTU_REQ was outstanding") + } + } +} + +#[cfg(test)] +mod test { + use crate::utils::task::{block_on_locally, try_await}; + + use super::*; + + const NEW_MTU: usize = 51; + const ANOTHER_NEW_MTU: usize = 52; + + #[test] + fn test_default_mtu() { + let mtu = AttMtu::new(); + + let stable_value = mtu.snapshot_or_default(); + let latest_value = tokio_test::block_on(mtu.snapshot()).unwrap(); + + assert_eq!(stable_value, DEFAULT_ATT_MTU); + assert_eq!(latest_value, DEFAULT_ATT_MTU); + } + + #[test] + fn test_guaranteed_mtu_during_client_negotiation() { + // arrange + let mtu = AttMtu::new(); + + // act: send an MTU_REQ and validate snapshotted value + mtu.handle_event(MtuEvent::OutgoingRequest).unwrap(); + let stable_value = mtu.snapshot_or_default(); + + // assert: we use the default MTU for requests handled + // while our request is pending + assert_eq!(stable_value, DEFAULT_ATT_MTU); + } + + #[test] + fn test_mtu_blocking_snapshot_during_client_negotiation() { + block_on_locally(async move { + // arrange + let mtu = AttMtu::new(); + + // act: send an MTU_REQ + mtu.handle_event(MtuEvent::OutgoingRequest).unwrap(); + // take snapshot of pending future + let pending_mtu = try_await(mtu.snapshot()).await.unwrap_err(); + // resolve MTU_REQ + mtu.handle_event(MtuEvent::IncomingResponse(NEW_MTU)).unwrap(); + + // assert: that the snapshot resolved with the NEW_MTU + assert_eq!(pending_mtu.await.unwrap(), NEW_MTU); + }); + } + + #[test] + fn test_receive_mtu_request() { + block_on_locally(async move { + // arrange + let mtu = AttMtu::new(); + + // act: receive an MTU_REQ + mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap(); + // take snapshot + let snapshot = mtu.snapshot().await; + + // assert: that the snapshot resolved with the NEW_MTU + assert_eq!(snapshot.unwrap(), NEW_MTU); + }); + } + + #[test] + fn test_client_then_server_negotiation() { + block_on_locally(async move { + // arrange + let mtu = AttMtu::new(); + + // act: send an MTU_REQ + mtu.handle_event(MtuEvent::OutgoingRequest).unwrap(); + // receive an MTU_RESP + mtu.handle_event(MtuEvent::IncomingResponse(NEW_MTU)).unwrap(); + // receive an MTU_REQ + mtu.handle_event(MtuEvent::IncomingRequest(ANOTHER_NEW_MTU)).unwrap(); + // take snapshot + let snapshot = mtu.snapshot().await; + + // assert: that the snapshot resolved with ANOTHER_NEW_MTU + assert_eq!(snapshot.unwrap(), ANOTHER_NEW_MTU); + }); + } + + #[test] + fn test_server_negotiation_then_pending_client_default_value() { + block_on_locally(async move { + // arrange + let mtu = AttMtu::new(); + + // act: receive an MTU_REQ + mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap(); + // send a MTU_REQ + mtu.handle_event(MtuEvent::OutgoingRequest).unwrap(); + // take snapshot for requests + let snapshot = mtu.snapshot_or_default(); + + // assert: that the snapshot resolved to NEW_MTU + assert_eq!(snapshot, NEW_MTU); + }); + } + + #[test] + fn test_server_negotiation_then_pending_client_finalized_value() { + block_on_locally(async move { + // arrange + let mtu = AttMtu::new(); + + // act: receive an MTU_REQ + mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap(); + // send a MTU_REQ + mtu.handle_event(MtuEvent::OutgoingRequest).unwrap(); + // take snapshot of pending future + let snapshot = try_await(mtu.snapshot()).await.unwrap_err(); + // receive MTU_RESP + mtu.handle_event(MtuEvent::IncomingResponse(ANOTHER_NEW_MTU)).unwrap(); + + // assert: that the snapshot resolved to ANOTHER_NEW_MTU + assert_eq!(snapshot.await.unwrap(), ANOTHER_NEW_MTU); + }); + } + + #[test] + fn test_mtu_dropped_while_pending() { + block_on_locally(async move { + // arrange + let mtu = AttMtu::new(); + + // act: send a MTU_REQ + mtu.handle_event(MtuEvent::OutgoingRequest).unwrap(); + // take snapshot and store pending future + let pending_mtu = try_await(mtu.snapshot()).await.unwrap_err(); + // drop the mtu (when the bearer closes) + drop(mtu); + + // assert: that the snapshot resolves to None since the bearer is gone + assert!(pending_mtu.await.is_none()); + }); + } +} diff --git a/system/rust/src/gatt/server/att_server_bearer.rs b/system/rust/src/gatt/server/att_server_bearer.rs index 4534cc54a7..645c74205d 100644 --- a/system/rust/src/gatt/server/att_server_bearer.rs +++ b/system/rust/src/gatt/server/att_server_bearer.rs @@ -4,6 +4,7 @@ use std::{cell::Cell, future::Future}; +use anyhow::Result; use log::{error, trace, warn}; use tokio::task::spawn_local; @@ -14,6 +15,7 @@ use crate::{ }, gatt::{ ids::AttHandle, + mtu::{AttMtu, MtuEvent}, opcode_types::{classify_opcode, OperationType}, }, packets::{ @@ -34,8 +36,6 @@ enum AttRequestState<T: AttDatabase> { Pending(Option<OwnedHandle<()>>), } -const DEFAULT_ATT_MTU: usize = 23; - /// The errors that can occur while trying to send a packet #[derive(Debug)] pub enum SendError { @@ -51,7 +51,7 @@ pub enum SendError { pub struct AttServerBearer<T: AttDatabase> { // general send_packet: Box<dyn Fn(AttBuilder) -> Result<(), SerializeError>>, - mtu: Cell<usize>, + mtu: AttMtu, // request state curr_request: Cell<AttRequestState<T>>, @@ -71,7 +71,7 @@ impl<T: AttDatabase + Clone + 'static> AttServerBearer<T> { let (indication_handler, pending_confirmation) = IndicationHandler::new(db.clone()); Self { send_packet: Box::new(send_packet), - mtu: Cell::new(DEFAULT_ATT_MTU), + mtu: AttMtu::new(), curr_request: AttRequestState::Idle(AttRequestHandler::new(db)).into(), @@ -116,27 +116,43 @@ impl<T: AttDatabase + Clone + 'static> WeakBoxRef<'_, AttServerBearer<T>> { trace!("sending indication for handle {handle:?}"); let locked_indication_handler = self.indication_handler.lock(); + let pending_mtu = self.mtu.snapshot(); let this = self.downgrade(); async move { - locked_indication_handler + // first wait until we are at the head of the queue and are ready to send + // indications + let mut indication_handler = locked_indication_handler .await .ok_or_else(|| { warn!("indication for handle {handle:?} cancelled while queued since the connection dropped"); IndicationError::SendError(SendError::ConnectionDropped) - })? - .send(handle, data, |packet| this.try_send_packet(packet)) + })?; + // then, if MTU negotiation is taking place, wait for it to complete + let mtu = pending_mtu .await + .ok_or_else(|| { + warn!("indication for handle {handle:?} cancelled while waiting for MTU exchange to complete since the connection dropped"); + IndicationError::SendError(SendError::ConnectionDropped) + })?; + // finally, send, and wait for a response + indication_handler.send(handle, data, mtu, |packet| this.try_send_packet(packet)).await } } + /// Handle a snooped MTU event, to update the MTU we use for our various + /// operations + pub fn handle_mtu_event(&self, mtu_event: MtuEvent) -> Result<()> { + self.mtu.handle_event(mtu_event) + } + fn handle_request(&self, packet: AttView<'_>) { let curr_request = self.curr_request.replace(AttRequestState::Pending(None)); self.curr_request.replace(match curr_request { AttRequestState::Idle(mut request_handler) => { // even if the MTU is updated afterwards, 5.3 3F 3.4.2.2 states that the // request-time MTU should be used - let mtu = self.mtu.get(); + let mtu = self.mtu.snapshot_or_default(); let packet = packet.to_owned_packet(); let this = self.downgrade(); let task = spawn_local(async move { @@ -220,7 +236,7 @@ mod test { }, utils::{ packet::{build_att_data, build_att_view_or_crash}, - task::block_on_locally, + task::{block_on_locally, try_await}, }, }; @@ -557,4 +573,102 @@ mod test { )); }); } + + #[test] + fn test_single_indication_pending_mtu() { + block_on_locally(async { + // arrange: pending MTU negotiation + let (conn, mut rx) = open_connection(); + conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap(); + + // act: try to send an indication with a large payload size + let _ = + try_await(conn.as_ref().send_indication( + VALID_HANDLE, + AttAttributeDataChild::RawData((1..50).collect()), + )) + .await; + // then resolve the MTU negotiation with a large MTU + conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(100)).unwrap(); + + // assert: the indication was sent + assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::HANDLE_VALUE_INDICATION); + }); + } + + #[test] + fn test_single_indication_pending_mtu_fail() { + block_on_locally(async { + // arrange: pending MTU negotiation + let (conn, _) = open_connection(); + conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap(); + + // act: try to send an indication with a large payload size + let pending_mtu = + try_await(conn.as_ref().send_indication( + VALID_HANDLE, + AttAttributeDataChild::RawData((1..50).collect()), + )) + .await + .unwrap_err(); + // then resolve the MTU negotiation with a small MTU + conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(32)).unwrap(); + + // assert: the indication failed to send + assert!(matches!(pending_mtu.await, Err(IndicationError::DataExceedsMtu { .. }))); + }); + } + + #[test] + fn test_server_transaction_pending_mtu() { + block_on_locally(async { + // arrange: pending MTU negotiation + let (conn, mut rx) = open_connection(); + conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap(); + + // act: send server packet + conn.as_ref().handle_packet( + build_att_view_or_crash(AttReadRequestBuilder { + attribute_handle: VALID_HANDLE.into(), + }) + .view(), + ); + + // assert: that we reply even while the MTU req is outstanding + assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::READ_RESPONSE); + }); + } + + #[test] + fn test_queued_indication_pending_mtu_uses_mtu_on_dequeue() { + block_on_locally(async { + // arrange: an outstanding indication + let (conn, mut rx) = open_connection(); + let _ = + try_await(conn.as_ref().send_indication( + VALID_HANDLE, + AttAttributeDataChild::RawData([1, 2, 3].into()), + )) + .await; + rx.recv().await.unwrap(); // flush rx_queue + + // act: enqueue an indication with a large payload + let _ = + try_await(conn.as_ref().send_indication( + VALID_HANDLE, + AttAttributeDataChild::RawData((1..50).collect()), + )) + .await; + // then perform MTU negotiation to upgrade to a large MTU + conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap(); + conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(512)).unwrap(); + // finally resolve the first indication, so the second indication can be sent + conn.as_ref().handle_packet( + build_att_view_or_crash(AttHandleValueConfirmationBuilder {}).view(), + ); + + // assert: the second indication successfully sent (so it used the new MTU) + assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::HANDLE_VALUE_INDICATION); + }); + } } diff --git a/system/rust/src/gatt/server/indication_handler.rs b/system/rust/src/gatt/server/indication_handler.rs index e4469395fc..c69a68f91c 100644 --- a/system/rust/src/gatt/server/indication_handler.rs +++ b/system/rust/src/gatt/server/indication_handler.rs @@ -8,7 +8,7 @@ use tokio::{ use crate::{ gatt::ids::AttHandle, - packets::{AttAttributeDataChild, AttChild, AttHandleValueIndicationBuilder}, + packets::{AttAttributeDataChild, AttChild, AttHandleValueIndicationBuilder, Serializable}, utils::packet::build_att_data, }; @@ -20,6 +20,12 @@ use super::{ #[derive(Debug)] /// Errors that can occur while sending an indication pub enum IndicationError { + /// The provided data exceeds the MTU limitations + DataExceedsMtu { + /// The actual max payload size permitted + /// (ATT_MTU - 3, since 3 bytes are needed for the header) + mtu: usize, + }, /// The indicated attribute handle does not exist AttributeNotFound, /// The indicated attribute does not support indications @@ -47,8 +53,19 @@ impl<T: AttDatabase> IndicationHandler<T> { &mut self, handle: AttHandle, data: AttAttributeDataChild, + mtu: usize, send_packet: impl FnOnce(AttChild) -> Result<(), SendError>, ) -> Result<(), IndicationError> { + let data_size = data + .size_in_bits() + .map_err(SendError::SerializeError) + .map_err(IndicationError::SendError)?; + // As per Core Spec 5.3 Vol 3F 3.4.7.2, the indicated value must be at most + // ATT_MTU-3 + if data_size > (mtu - 3) * 8 { + return Err(IndicationError::DataExceedsMtu { mtu: mtu - 3 }); + } + if !self .db .snapshot() @@ -120,6 +137,7 @@ mod test { const HANDLE: AttHandle = AttHandle(1); const NONEXISTENT_HANDLE: AttHandle = AttHandle(2); const NON_INDICATE_HANDLE: AttHandle = AttHandle(3); + const MTU: usize = 32; fn get_data() -> AttAttributeDataChild { AttAttributeDataChild::RawData([1, 2, 3].into()) @@ -157,7 +175,7 @@ mod test { // act: send an indication spawn_local(async move { indication_handler - .send(HANDLE, get_data(), move |packet| { + .send(HANDLE, get_data(), MTU, move |packet| { tx.send(packet).unwrap(); Ok(()) }) @@ -187,7 +205,7 @@ mod test { // act: send an indication on a nonexistent handle let ret = indication_handler - .send(NONEXISTENT_HANDLE, get_data(), move |_| unreachable!()) + .send(NONEXISTENT_HANDLE, get_data(), MTU, move |_| unreachable!()) .await; // assert: that we failed with IndicationError::AttributeNotFound @@ -204,7 +222,7 @@ mod test { // act: send an indication on an attribute that does not support indications let ret = indication_handler - .send(NON_INDICATE_HANDLE, get_data(), move |_| unreachable!()) + .send(NON_INDICATE_HANDLE, get_data(), MTU, move |_| unreachable!()) .await; // assert: that we failed with IndicationError::IndicationsNotSupported @@ -223,7 +241,7 @@ mod test { // act: send an indication let pending_result = spawn_local(async move { indication_handler - .send(HANDLE, get_data(), move |packet| { + .send(HANDLE, get_data(), MTU, move |packet| { tx.send(packet).unwrap(); Ok(()) }) @@ -249,7 +267,7 @@ mod test { // act: send an indication let pending_result = spawn_local(async move { indication_handler - .send(HANDLE, get_data(), move |packet| { + .send(HANDLE, get_data(), MTU, move |packet| { tx.send(packet).unwrap(); Ok(()) }) @@ -281,7 +299,7 @@ mod test { // act: send an indication let pending_result = spawn_local(async move { indication_handler - .send(HANDLE, get_data(), move |packet| { + .send(HANDLE, get_data(), MTU, move |packet| { tx.send(packet).unwrap(); Ok(()) }) @@ -306,7 +324,6 @@ mod test { fn test_indication_timeout() { block_on_locally(async move { // arrange: send a few confirmations in advance - tokio::time::pause(); let (mut indication_handler, confirmation_watcher) = IndicationHandler::new(get_att_database()); let (tx, rx) = oneshot::channel(); @@ -317,7 +334,7 @@ mod test { let time_sent = Instant::now(); let pending_result = spawn_local(async move { indication_handler - .send(HANDLE, get_data(), move |packet| { + .send(HANDLE, get_data(), MTU, move |packet| { tx.send(packet).unwrap(); Ok(()) }) @@ -339,4 +356,27 @@ mod test { assert!(time_slept < Duration::from_secs(31)); }); } + + #[test] + fn test_mtu_exceeds() { + block_on_locally(async move { + // arrange + let (mut indication_handler, _confirmation_watcher) = + IndicationHandler::new(get_att_database()); + + // act: send an indication with an ATT_MTU of 4 and data length of 3 + let res = indication_handler + .send( + HANDLE, + AttAttributeDataChild::RawData([1, 2, 3].into()), + 4, + move |_| unreachable!(), + ) + .await; + + // assert: that we got the expected error, indicating the max data size (not the + // ATT_MTU, but ATT_MTU-3) + assert!(matches!(res, Err(IndicationError::DataExceedsMtu { mtu: 1 }))); + }); + } } diff --git a/system/rust/src/packets.pdl b/system/rust/src/packets.pdl index 3550045e01..85ea01998a 100644 --- a/system/rust/src/packets.pdl +++ b/system/rust/src/packets.pdl @@ -215,3 +215,11 @@ packet AttHandleValueIndication : Att(opcode = HANDLE_VALUE_INDICATION) { } packet AttHandleValueConfirmation : Att(opcode = HANDLE_VALUE_CONFIRMATION) {} + +packet AttExchangeMtuRequest : Att(opcode = EXCHANGE_MTU_REQUEST) { + mtu: 16, +} + +packet AttExchangeMtuResponse : Att(opcode = EXCHANGE_MTU_RESPONSE) { + mtu: 16, +} diff --git a/system/rust/src/utils.rs b/system/rust/src/utils.rs index bcc35dd44e..242ff80399 100644 --- a/system/rust/src/utils.rs +++ b/system/rust/src/utils.rs @@ -2,4 +2,6 @@ pub mod owned_handle; pub mod packet; + +#[cfg(test)] pub mod task; diff --git a/system/rust/src/utils/packet.rs b/system/rust/src/utils/packet.rs index fd30afb019..3a5dd51ee4 100644 --- a/system/rust/src/utils/packet.rs +++ b/system/rust/src/utils/packet.rs @@ -40,6 +40,8 @@ pub fn HACK_child_to_opcode(child: &AttChild) -> AttOpcode { AttChild::AttWriteResponse(_) => AttOpcode::WRITE_RESPONSE, AttChild::AttHandleValueIndication(_) => AttOpcode::HANDLE_VALUE_INDICATION, AttChild::AttHandleValueConfirmation(_) => AttOpcode::HANDLE_VALUE_CONFIRMATION, + AttChild::AttExchangeMtuRequest(_) => AttOpcode::EXCHANGE_MTU_REQUEST, + AttChild::AttExchangeMtuResponse(_) => AttOpcode::EXCHANGE_MTU_RESPONSE, } } diff --git a/system/rust/src/utils/task.rs b/system/rust/src/utils/task.rs index 24c7c6b141..7bf447df83 100644 --- a/system/rust/src/utils/task.rs +++ b/system/rust/src/utils/task.rs @@ -1,10 +1,54 @@ -//! This module provides utilities relating to async tasks +//! This module provides utilities relating to async tasks, typically for usage +//! only in test -use std::future::Future; +use std::{future::Future, time::Duration}; -use tokio::{runtime::Builder, task::LocalSet}; +use tokio::{ + runtime::Builder, + select, + task::{spawn_local, LocalSet}, +}; /// Run the supplied future on a single-threaded runtime pub fn block_on_locally<T>(f: impl Future<Output = T>) -> T { - LocalSet::new().block_on(&Builder::new_current_thread().enable_time().build().unwrap(), f) + LocalSet::new().block_on( + &Builder::new_current_thread().enable_time().start_paused(true).build().unwrap(), + async move { + select! { + t = f => t, + // NOTE: this time should be LARGER than any meaningful delay in the stack + _ = tokio::time::sleep(Duration::from_secs(100000)) => { + panic!("test appears to be stuck"); + }, + } + }, + ) +} + +/// Check if the supplied future immediately resolves. +/// Returns Ok(T) if it resolves, or Err(JoinHandle<T>) if it does not. +/// Correctly handles spurious wakeups (unlike Future::poll). +/// +/// Unlike spawn/spawn_local, try_await guarantees that the future has been +/// polled when it returns. In addition, it is safe to drop the returned future, +/// since the underlying future will still run (i.e. it will not be cancelled). +/// +/// Thus, this is useful in tests where we want to force a particular order of +/// events, rather than letting spawn_local enqueue a task to the executor at +/// *some* point in the future. +/// +/// MUST only be run in an environment where time is mocked. +pub async fn try_await<T: 'static>( + f: impl Future<Output = T> + 'static, +) -> Result<T, impl Future<Output = T>> { + let mut handle = spawn_local(f); + + select! { + t = &mut handle => Ok(t.unwrap()), + // NOTE: this time should be SMALLER than any meaningful delay in the stack + // since time is frozen in test, we don't need to worry about racing with anything + _ = tokio::time::sleep(Duration::from_millis(10)) => { + Err(async { handle.await.unwrap() }) + }, + } } diff --git a/system/rust/tests/utils/mod.rs b/system/rust/tests/utils/mod.rs index 458583ad1e..a187e90058 100644 --- a/system/rust/tests/utils/mod.rs +++ b/system/rust/tests/utils/mod.rs @@ -5,6 +5,7 @@ use tokio::task::LocalSet; pub fn start_test(f: impl Future<Output = ()>) { tokio_test::block_on(async move { bt_common::init_logging(); + tokio::time::pause(); LocalSet::new().run_until(f).await; }); } diff --git a/system/stack/Android.bp b/system/stack/Android.bp index fc7212f3fa..c1976a4d8d 100644 --- a/system/stack/Android.bp +++ b/system/stack/Android.bp @@ -629,6 +629,7 @@ cc_test { ":TestCommonMockFunctions", ":TestMockStackBtm", ":TestMockStackSdp", + ":TestMockStackArbiter", "gatt/gatt_utils.cc", "test/common/mock_eatt.cc", "test/common/mock_gatt_layer.cc", diff --git a/system/stack/arbiter/acl_arbiter.cc b/system/stack/arbiter/acl_arbiter.cc index a3f781a253..2dd81dfb51 100644 --- a/system/stack/arbiter/acl_arbiter.cc +++ b/system/stack/arbiter/acl_arbiter.cc @@ -47,6 +47,18 @@ class PassthroughAclArbiter : public AclArbiter { return InterceptAction::FORWARD; } + virtual void OnOutgoingMtuReq(uint8_t tcb_idx) override { + // no-op + } + + virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) { + // no-op + } + + virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) { + // no-op + } + static PassthroughAclArbiter& Get() { static auto singleton = PassthroughAclArbiter(); return singleton; @@ -59,6 +71,9 @@ struct RustArbiterCallbacks { ::rust::Fn<void(uint8_t tcb_idx)> on_le_disconnect; ::rust::Fn<InterceptAction(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer)> intercept_packet; + ::rust::Fn<void(uint8_t tcb_idx)> on_outgoing_mtu_req; + ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_resp; + ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_req; }; RustArbiterCallbacks callbacks_{}; @@ -88,6 +103,21 @@ class RustGattAclArbiter : public AclArbiter { return callbacks_.intercept_packet(tcb_idx, std::move(vec)); } + virtual void OnOutgoingMtuReq(uint8_t tcb_idx) override { + LOG_DEBUG("Notifying Rust of outgoing MTU request"); + callbacks_.on_outgoing_mtu_req(tcb_idx); + } + + virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) { + LOG_DEBUG("Notifying Rust of incoming MTU response %zu", mtu); + callbacks_.on_incoming_mtu_resp(tcb_idx, mtu); + } + + virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) { + LOG_DEBUG("Notifying Rust of incoming MTU request %zu", mtu); + callbacks_.on_incoming_mtu_req(tcb_idx, mtu); + } + void SendPacketToPeer(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer) { tGATT_TCB* p_tcb = gatt_get_tcb_by_idx(tcb_idx); if (p_tcb != nullptr) { @@ -116,9 +146,13 @@ void StoreCallbacksFromRust( ::rust::Fn<void(uint8_t tcb_idx, uint8_t advertiser)> on_le_connect, ::rust::Fn<void(uint8_t tcb_idx)> on_le_disconnect, ::rust::Fn<InterceptAction(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer)> - intercept_packet) { + intercept_packet, + ::rust::Fn<void(uint8_t tcb_idx)> on_outgoing_mtu_req, + ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_resp, + ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_req) { LOG_INFO("Received callbacks from Rust, registering in Arbiter"); - callbacks_ = {on_le_connect, on_le_disconnect, intercept_packet}; + callbacks_ = {on_le_connect, on_le_disconnect, intercept_packet, + on_outgoing_mtu_req, on_incoming_mtu_resp, on_incoming_mtu_req}; } void SendPacketToPeer(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer) { diff --git a/system/stack/arbiter/acl_arbiter.h b/system/stack/arbiter/acl_arbiter.h index 21f4ffba11..485adb6b80 100644 --- a/system/stack/arbiter/acl_arbiter.h +++ b/system/stack/arbiter/acl_arbiter.h @@ -44,6 +44,10 @@ class AclArbiter { virtual InterceptAction InterceptAttPacket(uint8_t tcb_idx, const BT_HDR* packet) = 0; + virtual void OnOutgoingMtuReq(uint8_t tcb_idx) = 0; + virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) = 0; + virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) = 0; + AclArbiter() = default; AclArbiter(AclArbiter&& other) = default; AclArbiter& operator=(AclArbiter&& other) = default; @@ -54,7 +58,10 @@ void StoreCallbacksFromRust( ::rust::Fn<void(uint8_t tcb_idx, uint8_t advertiser)> on_le_connect, ::rust::Fn<void(uint8_t tcb_idx)> on_le_disconnect, ::rust::Fn<InterceptAction(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer)> - intercept_packet); + intercept_packet, + ::rust::Fn<void(uint8_t tcb_idx)> on_outgoing_mtu_req, + ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_resp, + ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_req); void SendPacketToPeer(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer); diff --git a/system/stack/gatt/gatt_api.cc b/system/stack/gatt/gatt_api.cc index 88f1e09687..1d6dff2469 100644 --- a/system/stack/gatt/gatt_api.cc +++ b/system/stack/gatt/gatt_api.cc @@ -38,6 +38,7 @@ #include "osi/include/allocator.h" #include "osi/include/list.h" #include "osi/include/log.h" +#include "stack/arbiter/acl_arbiter.h" #include "stack/btm/btm_dev.h" #include "stack/gatt/connection_manager.h" #include "stack/gatt/gatt_int.h" @@ -718,6 +719,8 @@ tGATT_STATUS GATTC_ConfigureMTU(uint16_t conn_id, uint16_t mtu) { gatt_cl_msg.mtu = mtu; LOG_DEBUG("Configuring ATT mtu size conn_id:%hu mtu:%hu", conn_id, mtu); + bluetooth::shim::arbiter::GetArbiter().OnOutgoingMtuReq(tcb_idx); + return attp_send_cl_msg(*p_clcb->p_tcb, p_clcb, GATT_REQ_MTU, &gatt_cl_msg); } diff --git a/system/stack/gatt/gatt_cl.cc b/system/stack/gatt/gatt_cl.cc index e2f88083c8..e2cfac0038 100644 --- a/system/stack/gatt/gatt_cl.cc +++ b/system/stack/gatt/gatt_cl.cc @@ -33,6 +33,7 @@ #include "osi/include/allocator.h" #include "osi/include/log.h" #include "osi/include/osi.h" +#include "stack/arbiter/acl_arbiter.h" #include "stack/eatt/eatt.h" #include "stack/include/bt_types.h" #include "types/bluetooth/uuid.h" @@ -1102,6 +1103,9 @@ void gatt_process_mtu_rsp(tGATT_TCB& tcb, tGATT_CLCB* p_clcb, uint16_t len, tcb.payload_size = mtu; } + bluetooth::shim::arbiter::GetArbiter().OnIncomingMtuResp(tcb.tcb_idx, + tcb.payload_size); + BTM_SetBleDataLength(tcb.peer_bda, tcb.payload_size + L2CAP_PKT_OVERHEAD); gatt_end_operation(p_clcb, status, NULL); diff --git a/system/stack/gatt/gatt_sr.cc b/system/stack/gatt/gatt_sr.cc index 9fefffd5c2..fddb3c2cf3 100644 --- a/system/stack/gatt/gatt_sr.cc +++ b/system/stack/gatt/gatt_sr.cc @@ -29,6 +29,7 @@ #include "osi/include/allocator.h" #include "osi/include/log.h" #include "osi/include/osi.h" +#include "stack/arbiter/acl_arbiter.h" #include "stack/eatt/eatt.h" #include "stack/include/bt_hdr.h" #include "stack/include/bt_types.h" @@ -831,6 +832,9 @@ static void gatts_process_mtu_req(tGATT_TCB& tcb, uint16_t cid, uint16_t len, attp_build_sr_msg(tcb, GATT_RSP_MTU, &gatt_sr_msg, tcb.payload_size); attp_send_sr_msg(tcb, cid, p_buf); + bluetooth::shim::arbiter::GetArbiter().OnIncomingMtuReq(tcb.tcb_idx, + tcb.payload_size); + tGATTS_DATA gatts_data; gatts_data.mtu = tcb.payload_size; /* Notify all registered applicaiton with new MTU size. Us a transaction ID */ diff --git a/system/test/Android.bp b/system/test/Android.bp index 3898e69af1..355fa3f0f9 100644 --- a/system/test/Android.bp +++ b/system/test/Android.bp @@ -199,6 +199,13 @@ filegroup { } filegroup { + name: "TestMockStackArbiter", + srcs: [ + "mock/mock_stack_arbiter_*.cc", + ], +} + +filegroup { name: "TestMockStackL2cap", srcs: [ "mock/mock_stack_l2cap_*.cc", diff --git a/system/test/mock/mock_stack_arbiter_acl_arbiter.cc b/system/test/mock/mock_stack_arbiter_acl_arbiter.cc new file mode 100644 index 0000000000..60ee86f1c6 --- /dev/null +++ b/system/test/mock/mock_stack_arbiter_acl_arbiter.cc @@ -0,0 +1,52 @@ +/* + * Copyright 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "stack/arbiter/acl_arbiter.h" + +namespace bluetooth { +namespace shim { +namespace arbiter { + +class MockAclArbiter : public AclArbiter { + public: + virtual void OnLeConnect(uint8_t tcb_idx, uint16_t advertiser_id) override {} + + virtual void OnLeDisconnect(uint8_t tcb_idx) override {} + + virtual InterceptAction InterceptAttPacket(uint8_t tcb_idx, + const BT_HDR* packet) override { + return InterceptAction::FORWARD; + } + + virtual void OnOutgoingMtuReq(uint8_t tcb_idx) override {} + + virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) {} + + virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) {} + + static MockAclArbiter& Get() { + static auto singleton = MockAclArbiter(); + return singleton; + } +}; + +AclArbiter& GetArbiter() { + return static_cast<AclArbiter&>(MockAclArbiter::Get()); +} + +} // namespace arbiter +} // namespace shim +} // namespace bluetooth
\ No newline at end of file |