summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Rahul Arya <aryarahul@google.com> 2023-03-08 03:15:03 +0000
committer Rahul Arya <aryarahul@google.com> 2023-03-09 01:26:28 +0000
commit189169e520cb0e480fa2ba41de5ddfc19c7c35e7 (patch)
treebab192ba3962276b73a2791e58077d6f29f6860f
parente8b68ea94aeb440e63c255bab747c3dcc6565276 (diff)
[Private GATT] Add support for MTU Exchange
Snoop MTU_REQ/RSP packets from legacy stack, and use them to track the MTU used in the isolated server. Bug: 255880936 Test: unit Change-Id: Ifcaa35be47abdbf714b592318184701645b55800
-rw-r--r--system/rust/src/gatt.rs1
-rw-r--r--system/rust/src/gatt/arbiter.rs58
-rw-r--r--system/rust/src/gatt/ffi.rs3
-rw-r--r--system/rust/src/gatt/mtu.rs251
-rw-r--r--system/rust/src/gatt/server/att_server_bearer.rs132
-rw-r--r--system/rust/src/gatt/server/indication_handler.rs58
-rw-r--r--system/rust/src/packets.pdl8
-rw-r--r--system/rust/src/utils.rs2
-rw-r--r--system/rust/src/utils/packet.rs2
-rw-r--r--system/rust/src/utils/task.rs52
-rw-r--r--system/rust/tests/utils/mod.rs1
-rw-r--r--system/stack/Android.bp1
-rw-r--r--system/stack/arbiter/acl_arbiter.cc38
-rw-r--r--system/stack/arbiter/acl_arbiter.h9
-rw-r--r--system/stack/gatt/gatt_api.cc3
-rw-r--r--system/stack/gatt/gatt_cl.cc4
-rw-r--r--system/stack/gatt/gatt_sr.cc4
-rw-r--r--system/test/Android.bp7
-rw-r--r--system/test/mock/mock_stack_arbiter_acl_arbiter.cc52
19 files changed, 657 insertions, 29 deletions
diff --git a/system/rust/src/gatt.rs b/system/rust/src/gatt.rs
index dd61f43683..c3e22b7004 100644
--- a/system/rust/src/gatt.rs
+++ b/system/rust/src/gatt.rs
@@ -7,6 +7,7 @@ pub mod channel;
pub mod ffi;
pub mod ids;
pub mod mocks;
+mod mtu;
pub mod opcode_types;
pub mod server;
diff --git a/system/rust/src/gatt/arbiter.rs b/system/rust/src/gatt/arbiter.rs
index d636f1fea7..cae6a3a2fe 100644
--- a/system/rust/src/gatt/arbiter.rs
+++ b/system/rust/src/gatt/arbiter.rs
@@ -7,13 +7,13 @@ use log::{error, info, trace};
use crate::{
do_in_rust_thread,
- gatt::server::att_server_bearer::AttServerBearer,
- packets::{OwnedAttView, OwnedPacket},
+ packets::{AttOpcode, OwnedAttView, OwnedPacket},
};
use super::{
ffi::{InterceptAction, StoreCallbacksFromRust},
ids::{AdvertiserId, ConnectionId, ServerId, TransportIndex},
+ mtu::MtuEvent,
opcode_types::{classify_opcode, OperationType},
};
@@ -32,7 +32,14 @@ pub struct Arbiter {
pub fn initialize_arbiter() {
*ARBITER.lock().unwrap() = Some(Arbiter::new());
- StoreCallbacksFromRust(on_le_connect, on_le_disconnect, intercept_packet);
+ StoreCallbacksFromRust(
+ on_le_connect,
+ on_le_disconnect,
+ intercept_packet,
+ |tcb_idx| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::OutgoingRequest),
+ |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingResponse(mtu)),
+ |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingRequest(mtu)),
+ );
}
/// Acquire the mutex holding the Arbiter and provide a mutable reference to the
@@ -93,6 +100,12 @@ impl Arbiter {
let att = OwnedAttView::try_parse(packet).ok()?;
+ if att.view().get_opcode() == AttOpcode::EXCHANGE_MTU_REQUEST {
+ // special case: this server opcode is handled by legacy stack, and we snoop
+ // on its handling, since the MTU is shared between the client + server
+ return None;
+ }
+
match classify_opcode(att.view().get_opcode()) {
OperationType::Command | OperationType::Request | OperationType::Confirmation => {
Some((att, conn_id))
@@ -127,6 +140,11 @@ impl Arbiter {
info!("processing disconnection on transport {tcb_idx:?}");
self.transport_to_owned_connection.remove(&tcb_idx)
}
+
+ /// Look up the conn_id for a given tcb_idx, if present
+ pub fn get_conn_id(&self, tcb_idx: TransportIndex) -> Option<ConnectionId> {
+ self.transport_to_owned_connection.get(&tcb_idx).copied()
+ }
}
fn on_le_connect(tcb_idx: u8, advertiser: u8) {
@@ -168,13 +186,30 @@ fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
}
}
+fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) {
+ if let Some(conn_id) = with_arbiter(|arbiter| arbiter.get_conn_id(tcb_idx)) {
+ do_in_rust_thread(move |modules| {
+ let Some(bearer) = modules.gatt_module.get_bearer(conn_id) else {
+ error!("Bearer for {conn_id:?} not found");
+ return;
+ };
+ if let Err(err) = bearer.handle_mtu_event(event) {
+ error!("{err:?}")
+ }
+ });
+ }
+}
+
#[cfg(test)]
mod test {
use super::*;
use crate::{
gatt::ids::AttHandle,
- packets::{AttBuilder, AttOpcode, AttReadRequestBuilder, Serializable},
+ packets::{
+ AttBuilder, AttExchangeMtuRequestBuilder, AttOpcode, AttReadRequestBuilder,
+ Serializable,
+ },
};
const TCB_IDX: TransportIndex = TransportIndex(1);
@@ -330,6 +365,21 @@ mod test {
}
#[test]
+ fn test_mtu_bypass() {
+ let mut arbiter = Arbiter::new();
+ arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID);
+ arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
+ let packet = AttBuilder {
+ opcode: AttOpcode::EXCHANGE_MTU_REQUEST,
+ _child_: AttExchangeMtuRequestBuilder { mtu: 64 }.into(),
+ };
+
+ let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet.to_vec().unwrap().into());
+
+ assert!(out.is_none());
+ }
+
+ #[test]
fn test_packet_bypass_when_not_isolated() {
let mut arbiter = Arbiter::new();
arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID);
diff --git a/system/rust/src/gatt/ffi.rs b/system/rust/src/gatt/ffi.rs
index 32dbb6c8c2..cf8f9583a2 100644
--- a/system/rust/src/gatt/ffi.rs
+++ b/system/rust/src/gatt/ffi.rs
@@ -149,6 +149,9 @@ mod inner {
on_le_connect: fn(tcb_idx: u8, advertiser: u8),
on_le_disconnect: fn(tcb_idx: u8),
intercept_packet: fn(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction,
+ on_outgoing_mtu_req: fn(tcb_idx: u8),
+ on_incoming_mtu_resp: fn(tcb_idx: u8, mtu: usize),
+ on_incoming_mtu_req: fn(tcb_idx: u8, mtu: usize),
);
/// Send an outgoing packet on the specified tcb_idx
diff --git a/system/rust/src/gatt/mtu.rs b/system/rust/src/gatt/mtu.rs
new file mode 100644
index 0000000000..838cfcda8b
--- /dev/null
+++ b/system/rust/src/gatt/mtu.rs
@@ -0,0 +1,251 @@
+//! The MTU on an ATT bearer is determined either by L2CAP (if EATT) or by the
+//! ATT_EXCHANGE_MTU procedure (if on an unenhanced bearer).
+//!
+//! In the latter case, the MTU may be either (1) unset, (2) pending, or (3)
+//! set. If the MTU is pending, ATT notifications/indications may not be sent.
+//! Refer to Core Spec 5.3 Vol 3F 3.4.2 MTU exchange for full details.
+
+use std::{cell::Cell, future::Future};
+
+use anyhow::{bail, Result};
+use log::info;
+use tokio::sync::OwnedMutexGuard;
+
+use crate::core::shared_mutex::SharedMutex;
+
+/// An MTU event that we have snooped
+pub enum MtuEvent {
+ /// We have sent an MTU_REQ
+ OutgoingRequest,
+ /// We have received an MTU_RESP
+ IncomingResponse(usize),
+ /// We have received an MTU_REQ (and will immediately reply)
+ IncomingRequest(usize),
+}
+
+/// The state of MTU negotiation on an unenhanced ATT bearer
+pub struct AttMtu {
+ /// The MTU we have committed to (i.e. sent a REQ and got a RESP, or
+ /// vice-versa)
+ previous_mtu: Cell<usize>,
+ /// The MTU we have committed or are about to commit to (if a REQ is
+ /// pending)
+ stable_mtu: SharedMutex<usize>,
+ /// Lock guard held if we are currrently performing MTU negotiation
+ pending_exchange: Cell<Option<OwnedMutexGuard<usize>>>,
+}
+
+// NOTE: this is only true for ATT, not EATT
+const DEFAULT_ATT_MTU: usize = 23;
+
+impl AttMtu {
+ /// Constructor
+ pub fn new() -> Self {
+ Self {
+ previous_mtu: Cell::new(DEFAULT_ATT_MTU),
+ stable_mtu: SharedMutex::new(DEFAULT_ATT_MTU),
+ pending_exchange: Cell::new(None),
+ }
+ }
+
+ /// Get the most recently negotiated MTU, or the default (if an MTU_REQ is
+ /// outstanding and we get an ATT_REQ)
+ pub fn snapshot_or_default(&self) -> usize {
+ self.stable_mtu.try_lock().as_deref().cloned().unwrap_or_else(|_| self.previous_mtu.get())
+ }
+
+ /// Get the most recently negotiated MTU, or block if negotiation is ongoing
+ /// (i.e. if an MTU_REQ is outstanding)
+ pub fn snapshot(&self) -> impl Future<Output = Option<usize>> {
+ let pending_snapshot = self.stable_mtu.lock();
+ async move { pending_snapshot.await.as_deref().cloned() }
+ }
+
+ /// Handle an MtuEvent and update the stored MTU
+ pub fn handle_event(&self, event: MtuEvent) -> Result<()> {
+ match event {
+ MtuEvent::OutgoingRequest => self.on_outgoing_request(),
+ MtuEvent::IncomingResponse(mtu) => self.on_incoming_response(mtu),
+ MtuEvent::IncomingRequest(mtu) => {
+ self.on_incoming_request(mtu);
+ Ok(())
+ }
+ }
+ }
+
+ fn on_outgoing_request(&self) -> Result<()> {
+ let Ok(pending_mtu) = self.stable_mtu.try_lock() else {
+ bail!("Sent ATT_EXCHANGE_MTU_REQ while an existing MTU exchange is taking place");
+ };
+ info!("Sending MTU_REQ, pausing indications/notifications");
+ self.pending_exchange.replace(Some(pending_mtu));
+ Ok(())
+ }
+
+ fn on_incoming_response(&self, mtu: usize) -> Result<()> {
+ let Some(mut pending_exchange) = self.pending_exchange.take() else {
+ bail!("Got ATT_EXCHANGE_MTU_RESP when transaction not taking place");
+ };
+ info!("Got an MTU_RESP of {mtu}");
+ *pending_exchange = mtu;
+ // note: since MTU_REQ can be sent at most once, this is a no-op, as the
+ // stable_mtu will never again be blocked we do it anyway for clarity
+ self.previous_mtu.set(mtu);
+ Ok(())
+ }
+
+ fn on_incoming_request(&self, mtu: usize) {
+ self.previous_mtu.set(mtu);
+ if let Ok(mut stable_mtu) = self.stable_mtu.try_lock() {
+ info!("Accepted an MTU_REQ of {mtu:?}");
+ *stable_mtu = mtu;
+ } else {
+ info!("Accepted an MTU_REQ while our own MTU_REQ was outstanding")
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::utils::task::{block_on_locally, try_await};
+
+ use super::*;
+
+ const NEW_MTU: usize = 51;
+ const ANOTHER_NEW_MTU: usize = 52;
+
+ #[test]
+ fn test_default_mtu() {
+ let mtu = AttMtu::new();
+
+ let stable_value = mtu.snapshot_or_default();
+ let latest_value = tokio_test::block_on(mtu.snapshot()).unwrap();
+
+ assert_eq!(stable_value, DEFAULT_ATT_MTU);
+ assert_eq!(latest_value, DEFAULT_ATT_MTU);
+ }
+
+ #[test]
+ fn test_guaranteed_mtu_during_client_negotiation() {
+ // arrange
+ let mtu = AttMtu::new();
+
+ // act: send an MTU_REQ and validate snapshotted value
+ mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
+ let stable_value = mtu.snapshot_or_default();
+
+ // assert: we use the default MTU for requests handled
+ // while our request is pending
+ assert_eq!(stable_value, DEFAULT_ATT_MTU);
+ }
+
+ #[test]
+ fn test_mtu_blocking_snapshot_during_client_negotiation() {
+ block_on_locally(async move {
+ // arrange
+ let mtu = AttMtu::new();
+
+ // act: send an MTU_REQ
+ mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
+ // take snapshot of pending future
+ let pending_mtu = try_await(mtu.snapshot()).await.unwrap_err();
+ // resolve MTU_REQ
+ mtu.handle_event(MtuEvent::IncomingResponse(NEW_MTU)).unwrap();
+
+ // assert: that the snapshot resolved with the NEW_MTU
+ assert_eq!(pending_mtu.await.unwrap(), NEW_MTU);
+ });
+ }
+
+ #[test]
+ fn test_receive_mtu_request() {
+ block_on_locally(async move {
+ // arrange
+ let mtu = AttMtu::new();
+
+ // act: receive an MTU_REQ
+ mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap();
+ // take snapshot
+ let snapshot = mtu.snapshot().await;
+
+ // assert: that the snapshot resolved with the NEW_MTU
+ assert_eq!(snapshot.unwrap(), NEW_MTU);
+ });
+ }
+
+ #[test]
+ fn test_client_then_server_negotiation() {
+ block_on_locally(async move {
+ // arrange
+ let mtu = AttMtu::new();
+
+ // act: send an MTU_REQ
+ mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
+ // receive an MTU_RESP
+ mtu.handle_event(MtuEvent::IncomingResponse(NEW_MTU)).unwrap();
+ // receive an MTU_REQ
+ mtu.handle_event(MtuEvent::IncomingRequest(ANOTHER_NEW_MTU)).unwrap();
+ // take snapshot
+ let snapshot = mtu.snapshot().await;
+
+ // assert: that the snapshot resolved with ANOTHER_NEW_MTU
+ assert_eq!(snapshot.unwrap(), ANOTHER_NEW_MTU);
+ });
+ }
+
+ #[test]
+ fn test_server_negotiation_then_pending_client_default_value() {
+ block_on_locally(async move {
+ // arrange
+ let mtu = AttMtu::new();
+
+ // act: receive an MTU_REQ
+ mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap();
+ // send a MTU_REQ
+ mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
+ // take snapshot for requests
+ let snapshot = mtu.snapshot_or_default();
+
+ // assert: that the snapshot resolved to NEW_MTU
+ assert_eq!(snapshot, NEW_MTU);
+ });
+ }
+
+ #[test]
+ fn test_server_negotiation_then_pending_client_finalized_value() {
+ block_on_locally(async move {
+ // arrange
+ let mtu = AttMtu::new();
+
+ // act: receive an MTU_REQ
+ mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap();
+ // send a MTU_REQ
+ mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
+ // take snapshot of pending future
+ let snapshot = try_await(mtu.snapshot()).await.unwrap_err();
+ // receive MTU_RESP
+ mtu.handle_event(MtuEvent::IncomingResponse(ANOTHER_NEW_MTU)).unwrap();
+
+ // assert: that the snapshot resolved to ANOTHER_NEW_MTU
+ assert_eq!(snapshot.await.unwrap(), ANOTHER_NEW_MTU);
+ });
+ }
+
+ #[test]
+ fn test_mtu_dropped_while_pending() {
+ block_on_locally(async move {
+ // arrange
+ let mtu = AttMtu::new();
+
+ // act: send a MTU_REQ
+ mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
+ // take snapshot and store pending future
+ let pending_mtu = try_await(mtu.snapshot()).await.unwrap_err();
+ // drop the mtu (when the bearer closes)
+ drop(mtu);
+
+ // assert: that the snapshot resolves to None since the bearer is gone
+ assert!(pending_mtu.await.is_none());
+ });
+ }
+}
diff --git a/system/rust/src/gatt/server/att_server_bearer.rs b/system/rust/src/gatt/server/att_server_bearer.rs
index 4534cc54a7..645c74205d 100644
--- a/system/rust/src/gatt/server/att_server_bearer.rs
+++ b/system/rust/src/gatt/server/att_server_bearer.rs
@@ -4,6 +4,7 @@
use std::{cell::Cell, future::Future};
+use anyhow::Result;
use log::{error, trace, warn};
use tokio::task::spawn_local;
@@ -14,6 +15,7 @@ use crate::{
},
gatt::{
ids::AttHandle,
+ mtu::{AttMtu, MtuEvent},
opcode_types::{classify_opcode, OperationType},
},
packets::{
@@ -34,8 +36,6 @@ enum AttRequestState<T: AttDatabase> {
Pending(Option<OwnedHandle<()>>),
}
-const DEFAULT_ATT_MTU: usize = 23;
-
/// The errors that can occur while trying to send a packet
#[derive(Debug)]
pub enum SendError {
@@ -51,7 +51,7 @@ pub enum SendError {
pub struct AttServerBearer<T: AttDatabase> {
// general
send_packet: Box<dyn Fn(AttBuilder) -> Result<(), SerializeError>>,
- mtu: Cell<usize>,
+ mtu: AttMtu,
// request state
curr_request: Cell<AttRequestState<T>>,
@@ -71,7 +71,7 @@ impl<T: AttDatabase + Clone + 'static> AttServerBearer<T> {
let (indication_handler, pending_confirmation) = IndicationHandler::new(db.clone());
Self {
send_packet: Box::new(send_packet),
- mtu: Cell::new(DEFAULT_ATT_MTU),
+ mtu: AttMtu::new(),
curr_request: AttRequestState::Idle(AttRequestHandler::new(db)).into(),
@@ -116,27 +116,43 @@ impl<T: AttDatabase + Clone + 'static> WeakBoxRef<'_, AttServerBearer<T>> {
trace!("sending indication for handle {handle:?}");
let locked_indication_handler = self.indication_handler.lock();
+ let pending_mtu = self.mtu.snapshot();
let this = self.downgrade();
async move {
- locked_indication_handler
+ // first wait until we are at the head of the queue and are ready to send
+ // indications
+ let mut indication_handler = locked_indication_handler
.await
.ok_or_else(|| {
warn!("indication for handle {handle:?} cancelled while queued since the connection dropped");
IndicationError::SendError(SendError::ConnectionDropped)
- })?
- .send(handle, data, |packet| this.try_send_packet(packet))
+ })?;
+ // then, if MTU negotiation is taking place, wait for it to complete
+ let mtu = pending_mtu
.await
+ .ok_or_else(|| {
+ warn!("indication for handle {handle:?} cancelled while waiting for MTU exchange to complete since the connection dropped");
+ IndicationError::SendError(SendError::ConnectionDropped)
+ })?;
+ // finally, send, and wait for a response
+ indication_handler.send(handle, data, mtu, |packet| this.try_send_packet(packet)).await
}
}
+ /// Handle a snooped MTU event, to update the MTU we use for our various
+ /// operations
+ pub fn handle_mtu_event(&self, mtu_event: MtuEvent) -> Result<()> {
+ self.mtu.handle_event(mtu_event)
+ }
+
fn handle_request(&self, packet: AttView<'_>) {
let curr_request = self.curr_request.replace(AttRequestState::Pending(None));
self.curr_request.replace(match curr_request {
AttRequestState::Idle(mut request_handler) => {
// even if the MTU is updated afterwards, 5.3 3F 3.4.2.2 states that the
// request-time MTU should be used
- let mtu = self.mtu.get();
+ let mtu = self.mtu.snapshot_or_default();
let packet = packet.to_owned_packet();
let this = self.downgrade();
let task = spawn_local(async move {
@@ -220,7 +236,7 @@ mod test {
},
utils::{
packet::{build_att_data, build_att_view_or_crash},
- task::block_on_locally,
+ task::{block_on_locally, try_await},
},
};
@@ -557,4 +573,102 @@ mod test {
));
});
}
+
+ #[test]
+ fn test_single_indication_pending_mtu() {
+ block_on_locally(async {
+ // arrange: pending MTU negotiation
+ let (conn, mut rx) = open_connection();
+ conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap();
+
+ // act: try to send an indication with a large payload size
+ let _ =
+ try_await(conn.as_ref().send_indication(
+ VALID_HANDLE,
+ AttAttributeDataChild::RawData((1..50).collect()),
+ ))
+ .await;
+ // then resolve the MTU negotiation with a large MTU
+ conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(100)).unwrap();
+
+ // assert: the indication was sent
+ assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::HANDLE_VALUE_INDICATION);
+ });
+ }
+
+ #[test]
+ fn test_single_indication_pending_mtu_fail() {
+ block_on_locally(async {
+ // arrange: pending MTU negotiation
+ let (conn, _) = open_connection();
+ conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap();
+
+ // act: try to send an indication with a large payload size
+ let pending_mtu =
+ try_await(conn.as_ref().send_indication(
+ VALID_HANDLE,
+ AttAttributeDataChild::RawData((1..50).collect()),
+ ))
+ .await
+ .unwrap_err();
+ // then resolve the MTU negotiation with a small MTU
+ conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(32)).unwrap();
+
+ // assert: the indication failed to send
+ assert!(matches!(pending_mtu.await, Err(IndicationError::DataExceedsMtu { .. })));
+ });
+ }
+
+ #[test]
+ fn test_server_transaction_pending_mtu() {
+ block_on_locally(async {
+ // arrange: pending MTU negotiation
+ let (conn, mut rx) = open_connection();
+ conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap();
+
+ // act: send server packet
+ conn.as_ref().handle_packet(
+ build_att_view_or_crash(AttReadRequestBuilder {
+ attribute_handle: VALID_HANDLE.into(),
+ })
+ .view(),
+ );
+
+ // assert: that we reply even while the MTU req is outstanding
+ assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::READ_RESPONSE);
+ });
+ }
+
+ #[test]
+ fn test_queued_indication_pending_mtu_uses_mtu_on_dequeue() {
+ block_on_locally(async {
+ // arrange: an outstanding indication
+ let (conn, mut rx) = open_connection();
+ let _ =
+ try_await(conn.as_ref().send_indication(
+ VALID_HANDLE,
+ AttAttributeDataChild::RawData([1, 2, 3].into()),
+ ))
+ .await;
+ rx.recv().await.unwrap(); // flush rx_queue
+
+ // act: enqueue an indication with a large payload
+ let _ =
+ try_await(conn.as_ref().send_indication(
+ VALID_HANDLE,
+ AttAttributeDataChild::RawData((1..50).collect()),
+ ))
+ .await;
+ // then perform MTU negotiation to upgrade to a large MTU
+ conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap();
+ conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(512)).unwrap();
+ // finally resolve the first indication, so the second indication can be sent
+ conn.as_ref().handle_packet(
+ build_att_view_or_crash(AttHandleValueConfirmationBuilder {}).view(),
+ );
+
+ // assert: the second indication successfully sent (so it used the new MTU)
+ assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::HANDLE_VALUE_INDICATION);
+ });
+ }
}
diff --git a/system/rust/src/gatt/server/indication_handler.rs b/system/rust/src/gatt/server/indication_handler.rs
index e4469395fc..c69a68f91c 100644
--- a/system/rust/src/gatt/server/indication_handler.rs
+++ b/system/rust/src/gatt/server/indication_handler.rs
@@ -8,7 +8,7 @@ use tokio::{
use crate::{
gatt::ids::AttHandle,
- packets::{AttAttributeDataChild, AttChild, AttHandleValueIndicationBuilder},
+ packets::{AttAttributeDataChild, AttChild, AttHandleValueIndicationBuilder, Serializable},
utils::packet::build_att_data,
};
@@ -20,6 +20,12 @@ use super::{
#[derive(Debug)]
/// Errors that can occur while sending an indication
pub enum IndicationError {
+ /// The provided data exceeds the MTU limitations
+ DataExceedsMtu {
+ /// The actual max payload size permitted
+ /// (ATT_MTU - 3, since 3 bytes are needed for the header)
+ mtu: usize,
+ },
/// The indicated attribute handle does not exist
AttributeNotFound,
/// The indicated attribute does not support indications
@@ -47,8 +53,19 @@ impl<T: AttDatabase> IndicationHandler<T> {
&mut self,
handle: AttHandle,
data: AttAttributeDataChild,
+ mtu: usize,
send_packet: impl FnOnce(AttChild) -> Result<(), SendError>,
) -> Result<(), IndicationError> {
+ let data_size = data
+ .size_in_bits()
+ .map_err(SendError::SerializeError)
+ .map_err(IndicationError::SendError)?;
+ // As per Core Spec 5.3 Vol 3F 3.4.7.2, the indicated value must be at most
+ // ATT_MTU-3
+ if data_size > (mtu - 3) * 8 {
+ return Err(IndicationError::DataExceedsMtu { mtu: mtu - 3 });
+ }
+
if !self
.db
.snapshot()
@@ -120,6 +137,7 @@ mod test {
const HANDLE: AttHandle = AttHandle(1);
const NONEXISTENT_HANDLE: AttHandle = AttHandle(2);
const NON_INDICATE_HANDLE: AttHandle = AttHandle(3);
+ const MTU: usize = 32;
fn get_data() -> AttAttributeDataChild {
AttAttributeDataChild::RawData([1, 2, 3].into())
@@ -157,7 +175,7 @@ mod test {
// act: send an indication
spawn_local(async move {
indication_handler
- .send(HANDLE, get_data(), move |packet| {
+ .send(HANDLE, get_data(), MTU, move |packet| {
tx.send(packet).unwrap();
Ok(())
})
@@ -187,7 +205,7 @@ mod test {
// act: send an indication on a nonexistent handle
let ret = indication_handler
- .send(NONEXISTENT_HANDLE, get_data(), move |_| unreachable!())
+ .send(NONEXISTENT_HANDLE, get_data(), MTU, move |_| unreachable!())
.await;
// assert: that we failed with IndicationError::AttributeNotFound
@@ -204,7 +222,7 @@ mod test {
// act: send an indication on an attribute that does not support indications
let ret = indication_handler
- .send(NON_INDICATE_HANDLE, get_data(), move |_| unreachable!())
+ .send(NON_INDICATE_HANDLE, get_data(), MTU, move |_| unreachable!())
.await;
// assert: that we failed with IndicationError::IndicationsNotSupported
@@ -223,7 +241,7 @@ mod test {
// act: send an indication
let pending_result = spawn_local(async move {
indication_handler
- .send(HANDLE, get_data(), move |packet| {
+ .send(HANDLE, get_data(), MTU, move |packet| {
tx.send(packet).unwrap();
Ok(())
})
@@ -249,7 +267,7 @@ mod test {
// act: send an indication
let pending_result = spawn_local(async move {
indication_handler
- .send(HANDLE, get_data(), move |packet| {
+ .send(HANDLE, get_data(), MTU, move |packet| {
tx.send(packet).unwrap();
Ok(())
})
@@ -281,7 +299,7 @@ mod test {
// act: send an indication
let pending_result = spawn_local(async move {
indication_handler
- .send(HANDLE, get_data(), move |packet| {
+ .send(HANDLE, get_data(), MTU, move |packet| {
tx.send(packet).unwrap();
Ok(())
})
@@ -306,7 +324,6 @@ mod test {
fn test_indication_timeout() {
block_on_locally(async move {
// arrange: send a few confirmations in advance
- tokio::time::pause();
let (mut indication_handler, confirmation_watcher) =
IndicationHandler::new(get_att_database());
let (tx, rx) = oneshot::channel();
@@ -317,7 +334,7 @@ mod test {
let time_sent = Instant::now();
let pending_result = spawn_local(async move {
indication_handler
- .send(HANDLE, get_data(), move |packet| {
+ .send(HANDLE, get_data(), MTU, move |packet| {
tx.send(packet).unwrap();
Ok(())
})
@@ -339,4 +356,27 @@ mod test {
assert!(time_slept < Duration::from_secs(31));
});
}
+
+ #[test]
+ fn test_mtu_exceeds() {
+ block_on_locally(async move {
+ // arrange
+ let (mut indication_handler, _confirmation_watcher) =
+ IndicationHandler::new(get_att_database());
+
+ // act: send an indication with an ATT_MTU of 4 and data length of 3
+ let res = indication_handler
+ .send(
+ HANDLE,
+ AttAttributeDataChild::RawData([1, 2, 3].into()),
+ 4,
+ move |_| unreachable!(),
+ )
+ .await;
+
+ // assert: that we got the expected error, indicating the max data size (not the
+ // ATT_MTU, but ATT_MTU-3)
+ assert!(matches!(res, Err(IndicationError::DataExceedsMtu { mtu: 1 })));
+ });
+ }
}
diff --git a/system/rust/src/packets.pdl b/system/rust/src/packets.pdl
index 3550045e01..85ea01998a 100644
--- a/system/rust/src/packets.pdl
+++ b/system/rust/src/packets.pdl
@@ -215,3 +215,11 @@ packet AttHandleValueIndication : Att(opcode = HANDLE_VALUE_INDICATION) {
}
packet AttHandleValueConfirmation : Att(opcode = HANDLE_VALUE_CONFIRMATION) {}
+
+packet AttExchangeMtuRequest : Att(opcode = EXCHANGE_MTU_REQUEST) {
+ mtu: 16,
+}
+
+packet AttExchangeMtuResponse : Att(opcode = EXCHANGE_MTU_RESPONSE) {
+ mtu: 16,
+}
diff --git a/system/rust/src/utils.rs b/system/rust/src/utils.rs
index bcc35dd44e..242ff80399 100644
--- a/system/rust/src/utils.rs
+++ b/system/rust/src/utils.rs
@@ -2,4 +2,6 @@
pub mod owned_handle;
pub mod packet;
+
+#[cfg(test)]
pub mod task;
diff --git a/system/rust/src/utils/packet.rs b/system/rust/src/utils/packet.rs
index fd30afb019..3a5dd51ee4 100644
--- a/system/rust/src/utils/packet.rs
+++ b/system/rust/src/utils/packet.rs
@@ -40,6 +40,8 @@ pub fn HACK_child_to_opcode(child: &AttChild) -> AttOpcode {
AttChild::AttWriteResponse(_) => AttOpcode::WRITE_RESPONSE,
AttChild::AttHandleValueIndication(_) => AttOpcode::HANDLE_VALUE_INDICATION,
AttChild::AttHandleValueConfirmation(_) => AttOpcode::HANDLE_VALUE_CONFIRMATION,
+ AttChild::AttExchangeMtuRequest(_) => AttOpcode::EXCHANGE_MTU_REQUEST,
+ AttChild::AttExchangeMtuResponse(_) => AttOpcode::EXCHANGE_MTU_RESPONSE,
}
}
diff --git a/system/rust/src/utils/task.rs b/system/rust/src/utils/task.rs
index 24c7c6b141..7bf447df83 100644
--- a/system/rust/src/utils/task.rs
+++ b/system/rust/src/utils/task.rs
@@ -1,10 +1,54 @@
-//! This module provides utilities relating to async tasks
+//! This module provides utilities relating to async tasks, typically for usage
+//! only in test
-use std::future::Future;
+use std::{future::Future, time::Duration};
-use tokio::{runtime::Builder, task::LocalSet};
+use tokio::{
+ runtime::Builder,
+ select,
+ task::{spawn_local, LocalSet},
+};
/// Run the supplied future on a single-threaded runtime
pub fn block_on_locally<T>(f: impl Future<Output = T>) -> T {
- LocalSet::new().block_on(&Builder::new_current_thread().enable_time().build().unwrap(), f)
+ LocalSet::new().block_on(
+ &Builder::new_current_thread().enable_time().start_paused(true).build().unwrap(),
+ async move {
+ select! {
+ t = f => t,
+ // NOTE: this time should be LARGER than any meaningful delay in the stack
+ _ = tokio::time::sleep(Duration::from_secs(100000)) => {
+ panic!("test appears to be stuck");
+ },
+ }
+ },
+ )
+}
+
+/// Check if the supplied future immediately resolves.
+/// Returns Ok(T) if it resolves, or Err(JoinHandle<T>) if it does not.
+/// Correctly handles spurious wakeups (unlike Future::poll).
+///
+/// Unlike spawn/spawn_local, try_await guarantees that the future has been
+/// polled when it returns. In addition, it is safe to drop the returned future,
+/// since the underlying future will still run (i.e. it will not be cancelled).
+///
+/// Thus, this is useful in tests where we want to force a particular order of
+/// events, rather than letting spawn_local enqueue a task to the executor at
+/// *some* point in the future.
+///
+/// MUST only be run in an environment where time is mocked.
+pub async fn try_await<T: 'static>(
+ f: impl Future<Output = T> + 'static,
+) -> Result<T, impl Future<Output = T>> {
+ let mut handle = spawn_local(f);
+
+ select! {
+ t = &mut handle => Ok(t.unwrap()),
+ // NOTE: this time should be SMALLER than any meaningful delay in the stack
+ // since time is frozen in test, we don't need to worry about racing with anything
+ _ = tokio::time::sleep(Duration::from_millis(10)) => {
+ Err(async { handle.await.unwrap() })
+ },
+ }
}
diff --git a/system/rust/tests/utils/mod.rs b/system/rust/tests/utils/mod.rs
index 458583ad1e..a187e90058 100644
--- a/system/rust/tests/utils/mod.rs
+++ b/system/rust/tests/utils/mod.rs
@@ -5,6 +5,7 @@ use tokio::task::LocalSet;
pub fn start_test(f: impl Future<Output = ()>) {
tokio_test::block_on(async move {
bt_common::init_logging();
+ tokio::time::pause();
LocalSet::new().run_until(f).await;
});
}
diff --git a/system/stack/Android.bp b/system/stack/Android.bp
index fc7212f3fa..c1976a4d8d 100644
--- a/system/stack/Android.bp
+++ b/system/stack/Android.bp
@@ -629,6 +629,7 @@ cc_test {
":TestCommonMockFunctions",
":TestMockStackBtm",
":TestMockStackSdp",
+ ":TestMockStackArbiter",
"gatt/gatt_utils.cc",
"test/common/mock_eatt.cc",
"test/common/mock_gatt_layer.cc",
diff --git a/system/stack/arbiter/acl_arbiter.cc b/system/stack/arbiter/acl_arbiter.cc
index a3f781a253..2dd81dfb51 100644
--- a/system/stack/arbiter/acl_arbiter.cc
+++ b/system/stack/arbiter/acl_arbiter.cc
@@ -47,6 +47,18 @@ class PassthroughAclArbiter : public AclArbiter {
return InterceptAction::FORWARD;
}
+ virtual void OnOutgoingMtuReq(uint8_t tcb_idx) override {
+ // no-op
+ }
+
+ virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) {
+ // no-op
+ }
+
+ virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) {
+ // no-op
+ }
+
static PassthroughAclArbiter& Get() {
static auto singleton = PassthroughAclArbiter();
return singleton;
@@ -59,6 +71,9 @@ struct RustArbiterCallbacks {
::rust::Fn<void(uint8_t tcb_idx)> on_le_disconnect;
::rust::Fn<InterceptAction(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer)>
intercept_packet;
+ ::rust::Fn<void(uint8_t tcb_idx)> on_outgoing_mtu_req;
+ ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_resp;
+ ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_req;
};
RustArbiterCallbacks callbacks_{};
@@ -88,6 +103,21 @@ class RustGattAclArbiter : public AclArbiter {
return callbacks_.intercept_packet(tcb_idx, std::move(vec));
}
+ virtual void OnOutgoingMtuReq(uint8_t tcb_idx) override {
+ LOG_DEBUG("Notifying Rust of outgoing MTU request");
+ callbacks_.on_outgoing_mtu_req(tcb_idx);
+ }
+
+ virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) {
+ LOG_DEBUG("Notifying Rust of incoming MTU response %zu", mtu);
+ callbacks_.on_incoming_mtu_resp(tcb_idx, mtu);
+ }
+
+ virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) {
+ LOG_DEBUG("Notifying Rust of incoming MTU request %zu", mtu);
+ callbacks_.on_incoming_mtu_req(tcb_idx, mtu);
+ }
+
void SendPacketToPeer(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer) {
tGATT_TCB* p_tcb = gatt_get_tcb_by_idx(tcb_idx);
if (p_tcb != nullptr) {
@@ -116,9 +146,13 @@ void StoreCallbacksFromRust(
::rust::Fn<void(uint8_t tcb_idx, uint8_t advertiser)> on_le_connect,
::rust::Fn<void(uint8_t tcb_idx)> on_le_disconnect,
::rust::Fn<InterceptAction(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer)>
- intercept_packet) {
+ intercept_packet,
+ ::rust::Fn<void(uint8_t tcb_idx)> on_outgoing_mtu_req,
+ ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_resp,
+ ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_req) {
LOG_INFO("Received callbacks from Rust, registering in Arbiter");
- callbacks_ = {on_le_connect, on_le_disconnect, intercept_packet};
+ callbacks_ = {on_le_connect, on_le_disconnect, intercept_packet,
+ on_outgoing_mtu_req, on_incoming_mtu_resp, on_incoming_mtu_req};
}
void SendPacketToPeer(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer) {
diff --git a/system/stack/arbiter/acl_arbiter.h b/system/stack/arbiter/acl_arbiter.h
index 21f4ffba11..485adb6b80 100644
--- a/system/stack/arbiter/acl_arbiter.h
+++ b/system/stack/arbiter/acl_arbiter.h
@@ -44,6 +44,10 @@ class AclArbiter {
virtual InterceptAction InterceptAttPacket(uint8_t tcb_idx,
const BT_HDR* packet) = 0;
+ virtual void OnOutgoingMtuReq(uint8_t tcb_idx) = 0;
+ virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) = 0;
+ virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) = 0;
+
AclArbiter() = default;
AclArbiter(AclArbiter&& other) = default;
AclArbiter& operator=(AclArbiter&& other) = default;
@@ -54,7 +58,10 @@ void StoreCallbacksFromRust(
::rust::Fn<void(uint8_t tcb_idx, uint8_t advertiser)> on_le_connect,
::rust::Fn<void(uint8_t tcb_idx)> on_le_disconnect,
::rust::Fn<InterceptAction(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer)>
- intercept_packet);
+ intercept_packet,
+ ::rust::Fn<void(uint8_t tcb_idx)> on_outgoing_mtu_req,
+ ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_resp,
+ ::rust::Fn<void(uint8_t tcb_idx, size_t mtu)> on_incoming_mtu_req);
void SendPacketToPeer(uint8_t tcb_idx, ::rust::Vec<uint8_t> buffer);
diff --git a/system/stack/gatt/gatt_api.cc b/system/stack/gatt/gatt_api.cc
index 88f1e09687..1d6dff2469 100644
--- a/system/stack/gatt/gatt_api.cc
+++ b/system/stack/gatt/gatt_api.cc
@@ -38,6 +38,7 @@
#include "osi/include/allocator.h"
#include "osi/include/list.h"
#include "osi/include/log.h"
+#include "stack/arbiter/acl_arbiter.h"
#include "stack/btm/btm_dev.h"
#include "stack/gatt/connection_manager.h"
#include "stack/gatt/gatt_int.h"
@@ -718,6 +719,8 @@ tGATT_STATUS GATTC_ConfigureMTU(uint16_t conn_id, uint16_t mtu) {
gatt_cl_msg.mtu = mtu;
LOG_DEBUG("Configuring ATT mtu size conn_id:%hu mtu:%hu", conn_id, mtu);
+ bluetooth::shim::arbiter::GetArbiter().OnOutgoingMtuReq(tcb_idx);
+
return attp_send_cl_msg(*p_clcb->p_tcb, p_clcb, GATT_REQ_MTU, &gatt_cl_msg);
}
diff --git a/system/stack/gatt/gatt_cl.cc b/system/stack/gatt/gatt_cl.cc
index e2f88083c8..e2cfac0038 100644
--- a/system/stack/gatt/gatt_cl.cc
+++ b/system/stack/gatt/gatt_cl.cc
@@ -33,6 +33,7 @@
#include "osi/include/allocator.h"
#include "osi/include/log.h"
#include "osi/include/osi.h"
+#include "stack/arbiter/acl_arbiter.h"
#include "stack/eatt/eatt.h"
#include "stack/include/bt_types.h"
#include "types/bluetooth/uuid.h"
@@ -1102,6 +1103,9 @@ void gatt_process_mtu_rsp(tGATT_TCB& tcb, tGATT_CLCB* p_clcb, uint16_t len,
tcb.payload_size = mtu;
}
+ bluetooth::shim::arbiter::GetArbiter().OnIncomingMtuResp(tcb.tcb_idx,
+ tcb.payload_size);
+
BTM_SetBleDataLength(tcb.peer_bda, tcb.payload_size + L2CAP_PKT_OVERHEAD);
gatt_end_operation(p_clcb, status, NULL);
diff --git a/system/stack/gatt/gatt_sr.cc b/system/stack/gatt/gatt_sr.cc
index 9fefffd5c2..fddb3c2cf3 100644
--- a/system/stack/gatt/gatt_sr.cc
+++ b/system/stack/gatt/gatt_sr.cc
@@ -29,6 +29,7 @@
#include "osi/include/allocator.h"
#include "osi/include/log.h"
#include "osi/include/osi.h"
+#include "stack/arbiter/acl_arbiter.h"
#include "stack/eatt/eatt.h"
#include "stack/include/bt_hdr.h"
#include "stack/include/bt_types.h"
@@ -831,6 +832,9 @@ static void gatts_process_mtu_req(tGATT_TCB& tcb, uint16_t cid, uint16_t len,
attp_build_sr_msg(tcb, GATT_RSP_MTU, &gatt_sr_msg, tcb.payload_size);
attp_send_sr_msg(tcb, cid, p_buf);
+ bluetooth::shim::arbiter::GetArbiter().OnIncomingMtuReq(tcb.tcb_idx,
+ tcb.payload_size);
+
tGATTS_DATA gatts_data;
gatts_data.mtu = tcb.payload_size;
/* Notify all registered applicaiton with new MTU size. Us a transaction ID */
diff --git a/system/test/Android.bp b/system/test/Android.bp
index 3898e69af1..355fa3f0f9 100644
--- a/system/test/Android.bp
+++ b/system/test/Android.bp
@@ -199,6 +199,13 @@ filegroup {
}
filegroup {
+ name: "TestMockStackArbiter",
+ srcs: [
+ "mock/mock_stack_arbiter_*.cc",
+ ],
+}
+
+filegroup {
name: "TestMockStackL2cap",
srcs: [
"mock/mock_stack_l2cap_*.cc",
diff --git a/system/test/mock/mock_stack_arbiter_acl_arbiter.cc b/system/test/mock/mock_stack_arbiter_acl_arbiter.cc
new file mode 100644
index 0000000000..60ee86f1c6
--- /dev/null
+++ b/system/test/mock/mock_stack_arbiter_acl_arbiter.cc
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "stack/arbiter/acl_arbiter.h"
+
+namespace bluetooth {
+namespace shim {
+namespace arbiter {
+
+class MockAclArbiter : public AclArbiter {
+ public:
+ virtual void OnLeConnect(uint8_t tcb_idx, uint16_t advertiser_id) override {}
+
+ virtual void OnLeDisconnect(uint8_t tcb_idx) override {}
+
+ virtual InterceptAction InterceptAttPacket(uint8_t tcb_idx,
+ const BT_HDR* packet) override {
+ return InterceptAction::FORWARD;
+ }
+
+ virtual void OnOutgoingMtuReq(uint8_t tcb_idx) override {}
+
+ virtual void OnIncomingMtuResp(uint8_t tcb_idx, size_t mtu) {}
+
+ virtual void OnIncomingMtuReq(uint8_t tcb_idx, size_t mtu) {}
+
+ static MockAclArbiter& Get() {
+ static auto singleton = MockAclArbiter();
+ return singleton;
+ }
+};
+
+AclArbiter& GetArbiter() {
+ return static_cast<AclArbiter&>(MockAclArbiter::Get());
+}
+
+} // namespace arbiter
+} // namespace shim
+} // namespace bluetooth \ No newline at end of file