adb: clean up transport disconnect operations.
Move operations from global functions into member functions.
Add unit test.
Change-Id: Id4543d8e78541eb08c8e629f180c605c699737ec
diff --git a/adb.cpp b/adb.cpp
index 49cf123..fa935f6 100644
--- a/adb.cpp
+++ b/adb.cpp
@@ -244,11 +244,11 @@
//Close the associated usb
t->online = 0;
- // This is necessary to avoid a race condition that occured when a transport closes
+ // This is necessary to avoid a race condition that occurred when a transport closes
// while a client socket is still active.
close_all_sockets(t);
- run_transport_disconnects(t);
+ t->RunDisconnects();
}
#if DEBUG_PACKETS
diff --git a/adb.h b/adb.h
index 6855f3b..0fb2008 100644
--- a/adb.h
+++ b/adb.h
@@ -157,8 +157,6 @@
{
void (*func)(void* opaque, atransport* t);
void* opaque;
- adisconnect* next;
- adisconnect* prev;
};
diff --git a/adb_auth_client.cpp b/adb_auth_client.cpp
index be28202..c3af024 100644
--- a/adb_auth_client.cpp
+++ b/adb_auth_client.cpp
@@ -47,7 +47,7 @@
static int framework_fd = -1;
static void usb_disconnected(void* unused, atransport* t);
-static struct adisconnect usb_disconnect = { usb_disconnected, 0, 0, 0 };
+static struct adisconnect usb_disconnect = { usb_disconnected, nullptr};
static atransport* usb_transport;
static bool needs_retry = false;
@@ -164,7 +164,6 @@
static void usb_disconnected(void* unused, atransport* t)
{
D("USB disconnect\n");
- remove_transport_disconnect(usb_transport, &usb_disconnect);
usb_transport = NULL;
needs_retry = false;
}
@@ -196,7 +195,7 @@
if (!usb_transport) {
usb_transport = t;
- add_transport_disconnect(t, &usb_disconnect);
+ t->AddDisconnect(&usb_disconnect);
}
if (framework_fd < 0) {
diff --git a/adb_listeners.cpp b/adb_listeners.cpp
index 8fb2d19..d5b1fd5 100644
--- a/adb_listeners.cpp
+++ b/adb_listeners.cpp
@@ -101,13 +101,15 @@
free((char*)l->connect_to);
if (l->transport) {
- remove_transport_disconnect(l->transport, &l->disconnect);
+ l->transport->RemoveDisconnect(&l->disconnect);
}
free(l);
}
-static void listener_disconnect(void* listener, atransport* t) {
- free_listener(reinterpret_cast<alistener*>(listener));
+static void listener_disconnect(void* arg, atransport*) {
+ alistener* listener = reinterpret_cast<alistener*>(arg);
+ listener->transport = nullptr;
+ free_listener(listener);
}
static int local_name_to_fd(const char* name, std::string* error) {
@@ -159,7 +161,7 @@
for (l = listener_list.next; l != &listener_list; l = l->next) {
if (!strcmp(local_name, l->local_name)) {
- listener_disconnect(l, l->transport);
+ free_listener(l);
return INSTALL_STATUS_OK;
}
}
@@ -174,7 +176,7 @@
// Never remove smart sockets.
if (l->connect_to[0] == '*')
continue;
- listener_disconnect(l, l->transport);
+ free_listener(l);
}
}
@@ -209,9 +211,9 @@
free((void*) l->connect_to);
l->connect_to = cto;
if (l->transport != transport) {
- remove_transport_disconnect(l->transport, &l->disconnect);
+ l->transport->RemoveDisconnect(&l->disconnect);
l->transport = transport;
- add_transport_disconnect(l->transport, &l->disconnect);
+ l->transport->AddDisconnect(&l->disconnect);
}
return INSTALL_STATUS_OK;
}
@@ -260,7 +262,7 @@
if (transport) {
listener->disconnect.opaque = listener;
listener->disconnect.func = listener_disconnect;
- add_transport_disconnect(transport, &listener->disconnect);
+ transport->AddDisconnect(&listener->disconnect);
}
return INSTALL_STATUS_OK;
diff --git a/transport.cpp b/transport.cpp
index 4dc5e4a..6ce5d7f 100644
--- a/transport.cpp
+++ b/transport.cpp
@@ -42,36 +42,6 @@
ADB_MUTEX_DEFINE( transport_lock );
-// Each atransport contains a list of adisconnects (t->disconnects).
-// An adisconnect contains a link to the next/prev adisconnect, a function
-// pointer to a disconnect callback which takes a void* piece of user data and
-// the atransport, and some user data for the callback (helpfully named
-// "opaque").
-//
-// The list is circular. New items are added to the entry member of the list
-// (t->disconnects) by add_transport_disconnect.
-//
-// run_transport_disconnects invokes each function in the list.
-//
-// Gotchas:
-// * run_transport_disconnects assumes that t->disconnects is non-null, so
-// this can't be run on a zeroed atransport.
-// * The callbacks in this list are not removed when called, and this function
-// is not guarded against running more than once. As such, ensure that this
-// function is not called multiple times on the same atransport.
-// TODO(danalbert): Just fix this so that it is guarded once you have tests.
-void run_transport_disconnects(atransport* t)
-{
- adisconnect* dis = t->disconnects.next;
-
- D("%s: run_transport_disconnects\n", t->serial);
- while (dis != &t->disconnects) {
- adisconnect* next = dis->next;
- dis->func( dis->opaque, t );
- dis = next;
- }
-}
-
static void dump_packet(const char* name, const char* func, apacket* p) {
unsigned command = p->msg.command;
int len = p->msg.data_length;
@@ -588,8 +558,6 @@
transport_list.push_front(t);
adb_mutex_unlock(&transport_lock);
- t->disconnects.next = t->disconnects.prev = &t->disconnects;
-
update_transports();
}
@@ -653,23 +621,6 @@
adb_mutex_unlock(&transport_lock);
}
-void add_transport_disconnect(atransport* t, adisconnect* dis)
-{
- adb_mutex_lock(&transport_lock);
- dis->next = &t->disconnects;
- dis->prev = dis->next->prev;
- dis->prev->next = dis;
- dis->next->prev = dis;
- adb_mutex_unlock(&transport_lock);
-}
-
-void remove_transport_disconnect(atransport* t, adisconnect* dis)
-{
- dis->prev->next = dis->next;
- dis->next->prev = dis->prev;
- dis->next = dis->prev = dis;
-}
-
static int qual_match(const char *to_test,
const char *prefix, const char *qual, bool sanitize_qual)
{
@@ -844,6 +795,21 @@
return has_feature(feature) && supported_features().count(feature) > 0;
}
+void atransport::AddDisconnect(adisconnect* disconnect) {
+ disconnects_.push_back(disconnect);
+}
+
+void atransport::RemoveDisconnect(adisconnect* disconnect) {
+ disconnects_.remove(disconnect);
+}
+
+void atransport::RunDisconnects() {
+ for (auto& disconnect : disconnects_) {
+ disconnect->func(disconnect->opaque, this);
+ }
+ disconnects_.clear();
+}
+
#if ADB_HOST
static void append_transport_info(std::string* result, const char* key,
diff --git a/transport.h b/transport.h
index abb26a7..3b56c55 100644
--- a/transport.h
+++ b/transport.h
@@ -19,6 +19,7 @@
#include <sys/types.h>
+#include <list>
#include <string>
#include <unordered_set>
@@ -71,9 +72,6 @@
int adb_port = -1; // Use for emulators (local transport)
bool kicked = false;
- // A list of adisconnect callbacks called when the transport is kicked.
- adisconnect disconnects = {};
-
void* key = nullptr;
unsigned char token[TOKEN_SIZE] = {};
fdevent auth_fde;
@@ -96,6 +94,10 @@
// feature.
bool CanUseFeature(const std::string& feature) const;
+ void AddDisconnect(adisconnect* disconnect);
+ void RemoveDisconnect(adisconnect* disconnect);
+ void RunDisconnects();
+
private:
// A set of features transmitted in the banner with the initial connection.
// This is stored in the banner as 'features=feature0,feature1,etc'.
@@ -103,6 +105,9 @@
int protocol_version;
size_t max_payload;
+ // A list of adisconnect callbacks called when the transport is kicked.
+ std::list<adisconnect*> disconnects_;
+
DISALLOW_COPY_AND_ASSIGN(atransport);
};
@@ -114,10 +119,7 @@
*/
atransport* acquire_one_transport(ConnectionState state, TransportType type,
const char* serial, std::string* error_out);
-void add_transport_disconnect(atransport* t, adisconnect* dis);
-void remove_transport_disconnect(atransport* t, adisconnect* dis);
void kick_transport(atransport* t);
-void run_transport_disconnects(atransport* t);
void update_transports(void);
void init_transport_registration(void);
diff --git a/transport_test.cpp b/transport_test.cpp
index 743d97d..10872ac 100644
--- a/transport_test.cpp
+++ b/transport_test.cpp
@@ -51,9 +51,6 @@
EXPECT_EQ(adb_port, rhs.adb_port);
EXPECT_EQ(kicked, rhs.kicked);
- EXPECT_EQ(
- 0, memcmp(&disconnects, &rhs.disconnects, sizeof(adisconnect)));
-
EXPECT_EQ(key, rhs.key);
EXPECT_EQ(0, memcmp(token, rhs.token, TOKEN_SIZE));
EXPECT_EQ(0, memcmp(&auth_fde, &rhs.auth_fde, sizeof(fdevent)));
@@ -118,12 +115,33 @@
ASSERT_EQ(expected, t);
}
-// Disabled because the function currently segfaults for a zeroed atransport. I
-// want to make sure I understand how this is working at all before I try fixing
-// that.
-TEST(transport, DISABLED_run_transport_disconnects_zeroed_atransport) {
+static void DisconnectFunc(void* arg, atransport*) {
+ int* count = reinterpret_cast<int*>(arg);
+ ++*count;
+}
+
+TEST(transport, RunDisconnects) {
atransport t;
- run_transport_disconnects(&t);
+ // RunDisconnects() can be called with an empty atransport.
+ t.RunDisconnects();
+
+ int count = 0;
+ adisconnect disconnect;
+ disconnect.func = DisconnectFunc;
+ disconnect.opaque = &count;
+ t.AddDisconnect(&disconnect);
+ t.RunDisconnects();
+ ASSERT_EQ(1, count);
+
+ // disconnect should have been removed automatically.
+ t.RunDisconnects();
+ ASSERT_EQ(1, count);
+
+ count = 0;
+ t.AddDisconnect(&disconnect);
+ t.RemoveDisconnect(&disconnect);
+ t.RunDisconnects();
+ ASSERT_EQ(0, count);
}
TEST(transport, add_feature) {