adb: relax serial matching rules.

Currently targeting a device by serial requires matching the serial
number exactly. This CL relaxes the matching rules for local transports
to ignore protocol prefixes and make the port optional:
  [tcp:|udp:]<hostname>[:port]

The purpose of this is to allow a user to set ANDROID_SERIAL to
something like "tcp:100.100.100.100" and have it work for both fastboot
and adb (assuming the device comes up at 100.100.100.100 in both
modes).

This CL also adds some unit tests for the modified functions to make
sure they work as expected.

Bug: 27340240
Change-Id: I006e0c70c84331ab44d05d0a0f462d06592eb879
diff --git a/socket.h b/socket.h
index 4083036..9eb1b19 100644
--- a/socket.h
+++ b/socket.h
@@ -114,4 +114,13 @@
 void connect_to_remote(asocket *s, const char *destination);
 void connect_to_smartsocket(asocket *s);
 
+// Internal functions that are only made available here for testing purposes.
+namespace internal {
+
+#if ADB_HOST
+char* skip_host_serial(const char* service);
+#endif
+
+}  // namespace internal
+
 #endif  // __ADB_SOCKET_H
diff --git a/socket_test.cpp b/socket_test.cpp
index 471ca09..5cbef6d 100644
--- a/socket_test.cpp
+++ b/socket_test.cpp
@@ -270,3 +270,49 @@
 }
 
 #endif  // defined(__linux__)
+
+#if ADB_HOST
+
+// Checks that skip_host_serial(serial) returns a pointer to the part of |serial| which matches
+// |expected|, otherwise logs the failure to gtest.
+void VerifySkipHostSerial(const std::string& serial, const char* expected) {
+    const char* result = internal::skip_host_serial(serial.c_str());
+    if (expected == nullptr) {
+        EXPECT_EQ(nullptr, result);
+    } else {
+        EXPECT_STREQ(expected, result);
+    }
+}
+
+// Check [tcp:|udp:]<serial>[:<port>]:<command> format.
+TEST(socket_test, test_skip_host_serial) {
+    for (const std::string& protocol : {"", "tcp:", "udp:"}) {
+        VerifySkipHostSerial(protocol, nullptr);
+        VerifySkipHostSerial(protocol + "foo", nullptr);
+
+        VerifySkipHostSerial(protocol + "foo:bar", ":bar");
+        VerifySkipHostSerial(protocol + "foo:bar:baz", ":bar:baz");
+
+        VerifySkipHostSerial(protocol + "foo:123:bar", ":bar");
+        VerifySkipHostSerial(protocol + "foo:123:456", ":456");
+        VerifySkipHostSerial(protocol + "foo:123:bar:baz", ":bar:baz");
+
+        // Don't register a port unless it's all numbers and ends with ':'.
+        VerifySkipHostSerial(protocol + "foo:123", ":123");
+        VerifySkipHostSerial(protocol + "foo:123bar:baz", ":123bar:baz");
+    }
+}
+
+// Check <prefix>:<serial>:<command> format.
+TEST(socket_test, test_skip_host_serial_prefix) {
+    for (const std::string& prefix : {"usb:", "product:", "model:", "device:"}) {
+        VerifySkipHostSerial(prefix, nullptr);
+        VerifySkipHostSerial(prefix + "foo", nullptr);
+
+        VerifySkipHostSerial(prefix + "foo:bar", ":bar");
+        VerifySkipHostSerial(prefix + "foo:bar:baz", ":bar:baz");
+        VerifySkipHostSerial(prefix + "foo:123:bar", ":123:bar");
+    }
+}
+
+#endif  // ADB_HOST
diff --git a/sockets.cpp b/sockets.cpp
index d8e4e93..c083ee1 100644
--- a/sockets.cpp
+++ b/sockets.cpp
@@ -26,6 +26,8 @@
 #include <unistd.h>
 
 #include <algorithm>
+#include <string>
+#include <vector>
 
 #if !ADB_HOST
 #include "cutils/properties.h"
@@ -623,43 +625,43 @@
 
 #if ADB_HOST
 
-#define PREFIX(str) { str, sizeof(str) - 1 }
-static const struct prefix_struct {
-    const char *str;
-    const size_t len;
-} prefixes[] = {
-    PREFIX("usb:"),
-    PREFIX("product:"),
-    PREFIX("model:"),
-    PREFIX("device:"),
-};
-static const int num_prefixes = (sizeof(prefixes) / sizeof(prefixes[0]));
+namespace internal {
 
-/* skip_host_serial return the position in a string
-   skipping over the 'serial' parameter in the ADB protocol,
-   where parameter string may be a host:port string containing
-   the protocol delimiter (colon). */
-static char *skip_host_serial(char *service) {
-    char *first_colon, *serial_end;
-    int i;
+// Returns the position in |service| following the target serial parameter. Serial format can be
+// any of:
+//   * [tcp:|udp:]<serial>[:<port>]:<command>
+//   * <prefix>:<serial>:<command>
+// Where <port> must be a base-10 number and <prefix> may be any of {usb,product,model,device}.
+//
+// The returned pointer will point to the ':' just before <command>, or nullptr if not found.
+char* skip_host_serial(const char* service) {
+    static const std::vector<std::string>& prefixes =
+        *(new std::vector<std::string>{"usb:", "product:", "model:", "device:"});
 
-    for (i = 0; i < num_prefixes; i++) {
-        if (!strncmp(service, prefixes[i].str, prefixes[i].len))
-            return strchr(service + prefixes[i].len, ':');
+    for (const std::string& prefix : prefixes) {
+        if (!strncmp(service, prefix.c_str(), prefix.length())) {
+            return strchr(service + prefix.length(), ':');
+        }
     }
 
-    first_colon = strchr(service, ':');
+    // For fastboot compatibility, ignore protocol prefixes.
+    if (!strncmp(service, "tcp:", 4) || !strncmp(service, "udp:", 4)) {
+        service += 4;
+    }
+
+    char* first_colon = strchr(service, ':');
     if (!first_colon) {
-        /* No colon in service string. */
-        return NULL;
+        // No colon in service string.
+        return nullptr;
     }
-    serial_end = first_colon;
+
+    char* serial_end = first_colon;
     if (isdigit(serial_end[1])) {
         serial_end++;
-        while ((*serial_end) && isdigit(*serial_end)) {
+        while (*serial_end && isdigit(*serial_end)) {
             serial_end++;
         }
-        if ((*serial_end) != ':') {
+        if (*serial_end != ':') {
             // Something other than numbers was found, reset the end.
             serial_end = first_colon;
         }
@@ -667,6 +669,8 @@
     return serial_end;
 }
 
+}  // namespace internal
+
 #endif // ADB_HOST
 
 static int smart_socket_enqueue(asocket *s, apacket *p)
@@ -725,7 +729,7 @@
         service += strlen("host-serial:");
 
         // serial number should follow "host:" and could be a host:port string.
-        serial_end = skip_host_serial(service);
+        serial_end = internal::skip_host_serial(service);
         if (serial_end) {
             *serial_end = 0; // terminate string
             serial = service;
diff --git a/transport.cpp b/transport.cpp
index d9180bc..e3340af 100644
--- a/transport.cpp
+++ b/transport.cpp
@@ -30,6 +30,7 @@
 #include <list>
 
 #include <android-base/logging.h>
+#include <android-base/parsenetaddress.h>
 #include <android-base/stringprintf.h>
 #include <android-base/strings.h>
 
@@ -679,11 +680,7 @@
 
         // Check for matching serial number.
         if (serial) {
-            if ((t->serial && !strcmp(serial, t->serial)) ||
-                (t->devpath && !strcmp(serial, t->devpath)) ||
-                qual_match(serial, "product:", t->product, false) ||
-                qual_match(serial, "model:", t->model, true) ||
-                qual_match(serial, "device:", t->device, false)) {
+            if (t->MatchesTarget(serial)) {
                 if (result) {
                     *error_out = "more than one device";
                     if (is_ambiguous) *is_ambiguous = true;
@@ -837,6 +834,43 @@
     disconnects_.clear();
 }
 
+bool atransport::MatchesTarget(const std::string& target) const {
+    if (serial) {
+        if (target == serial) {
+            return true;
+        } else if (type == kTransportLocal) {
+            // Local transports can match [tcp:|udp:]<hostname>[:port].
+            const char* local_target_ptr = target.c_str();
+
+            // For fastboot compatibility, ignore protocol prefixes.
+            if (android::base::StartsWith(target, "tcp:") ||
+                    android::base::StartsWith(target, "udp:")) {
+                local_target_ptr += 4;
+            }
+
+            // Parse our |serial| and the given |target| to check if the hostnames and ports match.
+            std::string serial_host, error;
+            int serial_port = -1;
+            if (android::base::ParseNetAddress(serial, &serial_host, &serial_port, nullptr,
+                                               &error)) {
+                // |target| may omit the port to default to ours.
+                std::string target_host;
+                int target_port = serial_port;
+                if (android::base::ParseNetAddress(local_target_ptr, &target_host, &target_port,
+                                                   nullptr, &error) &&
+                        serial_host == target_host && serial_port == target_port) {
+                    return true;
+                }
+            }
+        }
+    }
+
+    return (devpath && target == devpath) ||
+           qual_match(target.c_str(), "product:", product, false) ||
+           qual_match(target.c_str(), "model:", model, true) ||
+           qual_match(target.c_str(), "device:", device, false);
+}
+
 #if ADB_HOST
 
 static void append_transport_info(std::string* result, const char* key,
diff --git a/transport.h b/transport.h
index 4c0c008..5857249 100644
--- a/transport.h
+++ b/transport.h
@@ -107,6 +107,21 @@
     void RemoveDisconnect(adisconnect* disconnect);
     void RunDisconnects();
 
+    // Returns true if |target| matches this transport. A matching |target| can be any of:
+    //   * <serial>
+    //   * <devpath>
+    //   * product:<product>
+    //   * model:<model>
+    //   * device:<device>
+    //
+    // If this is a local transport, serial will also match [tcp:|udp:]<hostname>[:port] targets.
+    // For example, serial "100.100.100.100:5555" would match any of:
+    //   * 100.100.100.100
+    //   * tcp:100.100.100.100
+    //   * udp:100.100.100.100:5555
+    // This is to make it easier to use the same network target for both fastboot and adb.
+    bool MatchesTarget(const std::string& target) const;
+
 private:
     // A set of features transmitted in the banner with the initial connection.
     // This is stored in the banner as 'features=feature0,feature1,etc'.
diff --git a/transport_test.cpp b/transport_test.cpp
index 1bdea2a..2028ecc 100644
--- a/transport_test.cpp
+++ b/transport_test.cpp
@@ -218,3 +218,60 @@
     ASSERT_EQ(std::string("bar"), t.model);
     ASSERT_EQ(std::string("baz"), t.device);
 }
+
+TEST(transport, test_matches_target) {
+    std::string serial = "foo";
+    std::string devpath = "/path/to/bar";
+    std::string product = "test_product";
+    std::string model = "test_model";
+    std::string device = "test_device";
+
+    atransport t;
+    t.serial = &serial[0];
+    t.devpath = &devpath[0];
+    t.product = &product[0];
+    t.model = &model[0];
+    t.device = &device[0];
+
+    // These tests should not be affected by the transport type.
+    for (TransportType type : {kTransportAny, kTransportLocal}) {
+        t.type = type;
+
+        EXPECT_TRUE(t.MatchesTarget(serial));
+        EXPECT_TRUE(t.MatchesTarget(devpath));
+        EXPECT_TRUE(t.MatchesTarget("product:" + product));
+        EXPECT_TRUE(t.MatchesTarget("model:" + model));
+        EXPECT_TRUE(t.MatchesTarget("device:" + device));
+
+        // Product, model, and device don't match without the prefix.
+        EXPECT_FALSE(t.MatchesTarget(product));
+        EXPECT_FALSE(t.MatchesTarget(model));
+        EXPECT_FALSE(t.MatchesTarget(device));
+    }
+}
+
+TEST(transport, test_matches_target_local) {
+    std::string serial = "100.100.100.100:5555";
+
+    atransport t;
+    t.serial = &serial[0];
+
+    // Network address matching should only be used for local transports.
+    for (TransportType type : {kTransportAny, kTransportLocal}) {
+        t.type = type;
+        bool should_match = (type == kTransportLocal);
+
+        EXPECT_EQ(should_match, t.MatchesTarget("100.100.100.100"));
+        EXPECT_EQ(should_match, t.MatchesTarget("tcp:100.100.100.100"));
+        EXPECT_EQ(should_match, t.MatchesTarget("tcp:100.100.100.100:5555"));
+        EXPECT_EQ(should_match, t.MatchesTarget("udp:100.100.100.100"));
+        EXPECT_EQ(should_match, t.MatchesTarget("udp:100.100.100.100:5555"));
+
+        // Wrong protocol, hostname, or port should never match.
+        EXPECT_FALSE(t.MatchesTarget("100.100.100"));
+        EXPECT_FALSE(t.MatchesTarget("100.100.100.100:"));
+        EXPECT_FALSE(t.MatchesTarget("100.100.100.100:-1"));
+        EXPECT_FALSE(t.MatchesTarget("100.100.100.100:5554"));
+        EXPECT_FALSE(t.MatchesTarget("abc:100.100.100.100"));
+    }
+}