summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Kiran S <krns@google.com> 2024-05-14 03:50:41 +0000
committer Kiran S <krns@google.com> 2024-05-14 11:02:17 +0000
commit32e94e57abfd71dd0ec560416c6c0c8b2e18acae (patch)
treeb9d128b13c0096fed25daf5164be2b701a924421
parentcdf3a13551c2950d1400fedf77502c310d177f9b (diff)
Update the USB device filter to support interface name
The interface name will be used as an additional parameter while filtering devices based on interface class, subclass and protocol. This will be used to correctly filter MTP devices based on interface. The current logic of only using class, subclass and protocol will incorrectly identify some devices interfaces as supporting MTP Test: atest UsbManagerTests, Manually verified that interface name is used for filtering MTP devices correctly Bug: 312828160 Flag: android.hardware.usb.flags.enable_interface_name_device_filter Change-Id: I7f6538864a0464cf0fedf677ef075ee98ef02b15
-rw-r--r--core/java/android/hardware/usb/DeviceFilter.java48
-rw-r--r--core/java/android/hardware/usb/flags/usb_framework_flags.aconfig8
-rw-r--r--tests/UsbManagerTests/src/android/hardware/usb/DeviceFilterTest.java248
3 files changed, 291 insertions, 13 deletions
diff --git a/core/java/android/hardware/usb/DeviceFilter.java b/core/java/android/hardware/usb/DeviceFilter.java
index 66b0a426f35d..3a271b44eef2 100644
--- a/core/java/android/hardware/usb/DeviceFilter.java
+++ b/core/java/android/hardware/usb/DeviceFilter.java
@@ -18,6 +18,7 @@ package android.hardware.usb;
import android.annotation.NonNull;
import android.annotation.Nullable;
+import android.hardware.usb.flags.Flags;
import android.service.usb.UsbDeviceFilterProto;
import android.util.Slog;
@@ -57,9 +58,12 @@ public class DeviceFilter {
public final String mProductName;
// USB device serial number string (or null for unspecified)
public final String mSerialNumber;
+ // USB interface name (or null for unspecified). This will be used when matching devices using
+ // the available interfaces.
+ public final String mInterfaceName;
public DeviceFilter(int vid, int pid, int clasz, int subclass, int protocol,
- String manufacturer, String product, String serialnum) {
+ String manufacturer, String product, String serialnum, String interfaceName) {
mVendorId = vid;
mProductId = pid;
mClass = clasz;
@@ -68,6 +72,7 @@ public class DeviceFilter {
mManufacturerName = manufacturer;
mProductName = product;
mSerialNumber = serialnum;
+ mInterfaceName = interfaceName;
}
public DeviceFilter(UsbDevice device) {
@@ -79,6 +84,7 @@ public class DeviceFilter {
mManufacturerName = device.getManufacturerName();
mProductName = device.getProductName();
mSerialNumber = device.getSerialNumber();
+ mInterfaceName = null;
}
public DeviceFilter(@NonNull DeviceFilter filter) {
@@ -90,6 +96,7 @@ public class DeviceFilter {
mManufacturerName = filter.mManufacturerName;
mProductName = filter.mProductName;
mSerialNumber = filter.mSerialNumber;
+ mInterfaceName = filter.mInterfaceName;
}
public static DeviceFilter read(XmlPullParser parser)
@@ -102,7 +109,7 @@ public class DeviceFilter {
String manufacturerName = null;
String productName = null;
String serialNumber = null;
-
+ String interfaceName = null;
int count = parser.getAttributeCount();
for (int i = 0; i < count; i++) {
String name = parser.getAttributeName(i);
@@ -114,6 +121,8 @@ public class DeviceFilter {
productName = value;
} else if ("serial-number".equals(name)) {
serialNumber = value;
+ } else if ("interface-name".equals(name)) {
+ interfaceName = value;
} else {
int intValue;
int radix = 10;
@@ -144,7 +153,7 @@ public class DeviceFilter {
}
return new DeviceFilter(vendorId, productId,
deviceClass, deviceSubclass, deviceProtocol,
- manufacturerName, productName, serialNumber);
+ manufacturerName, productName, serialNumber, interfaceName);
}
public void write(XmlSerializer serializer) throws IOException {
@@ -173,13 +182,25 @@ public class DeviceFilter {
if (mSerialNumber != null) {
serializer.attribute(null, "serial-number", mSerialNumber);
}
+ if (mInterfaceName != null) {
+ serializer.attribute(null, "interface-name", mInterfaceName);
+ }
serializer.endTag(null, "usb-device");
}
- private boolean matches(int clasz, int subclass, int protocol) {
- return ((mClass == -1 || clasz == mClass) &&
- (mSubclass == -1 || subclass == mSubclass) &&
- (mProtocol == -1 || protocol == mProtocol));
+ private boolean matches(int usbClass, int subclass, int protocol) {
+ return ((mClass == -1 || usbClass == mClass)
+ && (mSubclass == -1 || subclass == mSubclass)
+ && (mProtocol == -1 || protocol == mProtocol));
+ }
+
+ private boolean matches(int usbClass, int subclass, int protocol, String interfaceName) {
+ if (Flags.enableInterfaceNameDeviceFilter()) {
+ return matches(usbClass, subclass, protocol)
+ && (mInterfaceName == null || mInterfaceName.equals(interfaceName));
+ } else {
+ return matches(usbClass, subclass, protocol);
+ }
}
public boolean matches(UsbDevice device) {
@@ -204,7 +225,7 @@ public class DeviceFilter {
for (int i = 0; i < count; i++) {
UsbInterface intf = device.getInterface(i);
if (matches(intf.getInterfaceClass(), intf.getInterfaceSubclass(),
- intf.getInterfaceProtocol())) return true;
+ intf.getInterfaceProtocol(), intf.getName())) return true;
}
return false;
@@ -320,11 +341,12 @@ public class DeviceFilter {
@Override
public String toString() {
- return "DeviceFilter[mVendorId=" + mVendorId + ",mProductId=" + mProductId +
- ",mClass=" + mClass + ",mSubclass=" + mSubclass +
- ",mProtocol=" + mProtocol + ",mManufacturerName=" + mManufacturerName +
- ",mProductName=" + mProductName + ",mSerialNumber=" + mSerialNumber +
- "]";
+ return "DeviceFilter[mVendorId=" + mVendorId + ",mProductId=" + mProductId
+ + ",mClass=" + mClass + ",mSubclass=" + mSubclass
+ + ",mProtocol=" + mProtocol + ",mManufacturerName=" + mManufacturerName
+ + ",mProductName=" + mProductName + ",mSerialNumber=" + mSerialNumber
+ + ",mInterfaceName=" + mInterfaceName
+ + "]";
}
/**
diff --git a/core/java/android/hardware/usb/flags/usb_framework_flags.aconfig b/core/java/android/hardware/usb/flags/usb_framework_flags.aconfig
index 94df16030cdb..40e5ffb141ab 100644
--- a/core/java/android/hardware/usb/flags/usb_framework_flags.aconfig
+++ b/core/java/android/hardware/usb/flags/usb_framework_flags.aconfig
@@ -16,3 +16,11 @@ flag {
description: "Feature flag for the api to check if a port supports mode change"
bug: "323470419"
}
+
+flag {
+ name: "enable_interface_name_device_filter"
+ is_exported: true
+ namespace: "usb"
+ description: "Feature flag to enable interface name as a parameter for device filter"
+ bug: "312828160"
+}
diff --git a/tests/UsbManagerTests/src/android/hardware/usb/DeviceFilterTest.java b/tests/UsbManagerTests/src/android/hardware/usb/DeviceFilterTest.java
new file mode 100644
index 000000000000..d6f3148e64f1
--- /dev/null
+++ b/tests/UsbManagerTests/src/android/hardware/usb/DeviceFilterTest.java
@@ -0,0 +1,248 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.hardware.usb;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import static junit.framework.Assert.assertFalse;
+import static junit.framework.Assert.assertTrue;
+
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import android.hardware.usb.flags.Flags;
+
+import androidx.test.runner.AndroidJUnit4;
+
+import com.android.dx.mockito.inline.extended.ExtendedMockito;
+import com.android.internal.util.XmlUtils;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mockito;
+import org.mockito.MockitoSession;
+import org.mockito.quality.Strictness;
+import org.xmlpull.v1.XmlPullParser;
+import org.xmlpull.v1.XmlPullParserFactory;
+import org.xmlpull.v1.XmlSerializer;
+
+import java.io.StringReader;
+
+/**
+ * Unit tests for {@link android.hardware.usb.DeviceFilter}.
+ */
+@RunWith(AndroidJUnit4.class)
+public class DeviceFilterTest {
+
+ private static final int VID = 10;
+ private static final int PID = 11;
+ private static final int CLASS = 12;
+ private static final int SUBCLASS = 13;
+ private static final int PROTOCOL = 14;
+ private static final String MANUFACTURER = "Google";
+ private static final String PRODUCT = "Test";
+ private static final String SERIAL_NO = "4AL23";
+ private static final String INTERFACE_NAME = "MTP";
+
+ private MockitoSession mStaticMockSession;
+
+ @Before
+ public void setUp() throws Exception {
+ mStaticMockSession = ExtendedMockito.mockitoSession()
+ .mockStatic(Flags.class)
+ .strictness(Strictness.WARN)
+ .startMocking();
+
+ when(Flags.enableInterfaceNameDeviceFilter()).thenReturn(true);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ mStaticMockSession.finishMocking();
+ }
+
+ @Test
+ public void testConstructorFromValues_interfaceNameIsInitialized() {
+ DeviceFilter deviceFilter = new DeviceFilter(
+ VID, PID, CLASS, SUBCLASS, PROTOCOL, MANUFACTURER,
+ PRODUCT, SERIAL_NO, INTERFACE_NAME
+ );
+
+ verifyDeviceFilterConfigurationExceptInterfaceName(deviceFilter);
+ assertThat(deviceFilter.mInterfaceName).isEqualTo(INTERFACE_NAME);
+ }
+
+ @Test
+ public void testConstructorFromUsbDevice_interfaceNameIsNull() {
+ UsbDevice usbDevice = Mockito.mock(UsbDevice.class);
+ when(usbDevice.getVendorId()).thenReturn(VID);
+ when(usbDevice.getProductId()).thenReturn(PID);
+ when(usbDevice.getDeviceClass()).thenReturn(CLASS);
+ when(usbDevice.getDeviceSubclass()).thenReturn(SUBCLASS);
+ when(usbDevice.getDeviceProtocol()).thenReturn(PROTOCOL);
+ when(usbDevice.getManufacturerName()).thenReturn(MANUFACTURER);
+ when(usbDevice.getProductName()).thenReturn(PRODUCT);
+ when(usbDevice.getSerialNumber()).thenReturn(SERIAL_NO);
+
+ DeviceFilter deviceFilter = new DeviceFilter(usbDevice);
+
+ verifyDeviceFilterConfigurationExceptInterfaceName(deviceFilter);
+ assertThat(deviceFilter.mInterfaceName).isEqualTo(null);
+ }
+
+ @Test
+ public void testConstructorFromDeviceFilter_interfaceNameIsInitialized() {
+ DeviceFilter originalDeviceFilter = new DeviceFilter(
+ VID, PID, CLASS, SUBCLASS, PROTOCOL, MANUFACTURER,
+ PRODUCT, SERIAL_NO, INTERFACE_NAME
+ );
+
+ DeviceFilter deviceFilter = new DeviceFilter(originalDeviceFilter);
+
+ verifyDeviceFilterConfigurationExceptInterfaceName(deviceFilter);
+ assertThat(deviceFilter.mInterfaceName).isEqualTo(INTERFACE_NAME);
+ }
+
+
+ @Test
+ public void testReadFromXml_interfaceNamePresent_propertyIsInitialized() throws Exception {
+ DeviceFilter deviceFilter = getDeviceFilterFromXml("<usb-device interface-name=\"MTP\"/>");
+
+ assertThat(deviceFilter.mInterfaceName).isEqualTo("MTP");
+ }
+
+ @Test
+ public void testReadFromXml_interfaceNameAbsent_propertyIsNull() throws Exception {
+ DeviceFilter deviceFilter = getDeviceFilterFromXml("<usb-device vendor-id=\"1\" />");
+
+ assertThat(deviceFilter.mInterfaceName).isEqualTo(null);
+ }
+
+ @Test
+ public void testWrite_withInterfaceName() throws Exception {
+ DeviceFilter deviceFilter = getDeviceFilterFromXml("<usb-device interface-name=\"MTP\"/>");
+ XmlSerializer serializer = Mockito.mock(XmlSerializer.class);
+
+ deviceFilter.write(serializer);
+
+ verify(serializer).attribute(null, "interface-name", "MTP");
+ }
+
+ @Test
+ public void testWrite_withoutInterfaceName() throws Exception {
+ DeviceFilter deviceFilter = getDeviceFilterFromXml("<usb-device vendor-id=\"1\" />");
+ XmlSerializer serializer = Mockito.mock(XmlSerializer.class);
+
+ deviceFilter.write(serializer);
+
+ verify(serializer, times(0)).attribute(eq(null), eq("interface-name"), any());
+ }
+
+ @Test
+ public void testToString() {
+ DeviceFilter deviceFilter = new DeviceFilter(
+ VID, PID, CLASS, SUBCLASS, PROTOCOL, MANUFACTURER,
+ PRODUCT, SERIAL_NO, INTERFACE_NAME
+ );
+
+ assertThat(deviceFilter.toString()).isEqualTo(
+ "DeviceFilter[mVendorId=10,mProductId=11,mClass=12,mSubclass=13,mProtocol=14,"
+ + "mManufacturerName=Google,mProductName=Test,mSerialNumber=4AL23,"
+ + "mInterfaceName=MTP]");
+ }
+
+ @Test
+ public void testMatch_interfaceNameMatches_returnTrue() throws Exception {
+ DeviceFilter deviceFilter = getDeviceFilterFromXml(
+ "<usb-device class=\"255\" subclass=\"255\" protocol=\"0\" "
+ + "interface-name=\"MTP\"/>");
+ UsbDevice usbDevice = Mockito.mock(UsbDevice.class);
+ when(usbDevice.getInterfaceCount()).thenReturn(1);
+ when(usbDevice.getInterface(0)).thenReturn(new UsbInterface(
+ /* id= */ 0,
+ /* alternateSetting= */ 0,
+ /* name= */ "MTP",
+ /* class= */ 255,
+ /* subClass= */ 255,
+ /* protocol= */ 0));
+
+ assertTrue(deviceFilter.matches(usbDevice));
+ }
+
+ @Test
+ public void testMatch_interfaceNameMismatch_returnFalse() throws Exception {
+ DeviceFilter deviceFilter = getDeviceFilterFromXml(
+ "<usb-device class=\"255\" subclass=\"255\" protocol=\"0\" "
+ + "interface-name=\"MTP\"/>");
+ UsbDevice usbDevice = Mockito.mock(UsbDevice.class);
+ when(usbDevice.getInterfaceCount()).thenReturn(1);
+ when(usbDevice.getInterface(0)).thenReturn(new UsbInterface(
+ /* id= */ 0,
+ /* alternateSetting= */ 0,
+ /* name= */ "UVC",
+ /* class= */ 255,
+ /* subClass= */ 255,
+ /* protocol= */ 0));
+
+ assertFalse(deviceFilter.matches(usbDevice));
+ }
+
+ @Test
+ public void testMatch_interfaceNameMismatchFlagDisabled_returnTrue() throws Exception {
+ when(Flags.enableInterfaceNameDeviceFilter()).thenReturn(false);
+ DeviceFilter deviceFilter = getDeviceFilterFromXml(
+ "<usb-device class=\"255\" subclass=\"255\" protocol=\"0\" "
+ + "interface-name=\"MTP\"/>");
+ UsbDevice usbDevice = Mockito.mock(UsbDevice.class);
+ when(usbDevice.getInterfaceCount()).thenReturn(1);
+ when(usbDevice.getInterface(0)).thenReturn(new UsbInterface(
+ /* id= */ 0,
+ /* alternateSetting= */ 0,
+ /* name= */ "UVC",
+ /* class= */ 255,
+ /* subClass= */ 255,
+ /* protocol= */ 0));
+
+ assertTrue(deviceFilter.matches(usbDevice));
+ }
+
+ private void verifyDeviceFilterConfigurationExceptInterfaceName(DeviceFilter deviceFilter) {
+ assertThat(deviceFilter.mVendorId).isEqualTo(VID);
+ assertThat(deviceFilter.mProductId).isEqualTo(PID);
+ assertThat(deviceFilter.mClass).isEqualTo(CLASS);
+ assertThat(deviceFilter.mSubclass).isEqualTo(SUBCLASS);
+ assertThat(deviceFilter.mProtocol).isEqualTo(PROTOCOL);
+ assertThat(deviceFilter.mManufacturerName).isEqualTo(MANUFACTURER);
+ assertThat(deviceFilter.mProductName).isEqualTo(PRODUCT);
+ assertThat(deviceFilter.mSerialNumber).isEqualTo(SERIAL_NO);
+ }
+
+ private DeviceFilter getDeviceFilterFromXml(String xml) throws Exception {
+ XmlPullParserFactory factory = XmlPullParserFactory.newInstance();
+ XmlPullParser parser = factory.newPullParser();
+ parser.setInput(new StringReader(xml));
+ XmlUtils.nextElement(parser);
+
+ return DeviceFilter.read(parser);
+ }
+
+}