diff --git a/src/app/tests/suites/certification/Test_TC_SC_4_3.yaml b/src/app/tests/suites/certification/Test_TC_SC_4_3.yaml deleted file mode 100644 index 8521cb24bc8daa..00000000000000 --- a/src/app/tests/suites/certification/Test_TC_SC_4_3.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2021 Project CHIP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Auto-generated scripts for harness use only, please review before automation. The endpoints and cluster names are currently set to default - -name: 3.4.3. [TC-SC-4.3] Discovery [DUT as Commissionee] - -PICS: - - MCORE.ROLE.COMMISSIONEE - -config: - nodeId: 0x12344321 - cluster: "Basic Information" - endpoint: 0 - -tests: - - label: "Note" - verification: | - Chip-tool command used below are an example to verify the DUT as controller test cases. For certification test, we expect DUT should have a capability or way to run the equivalent command. - disabled: true - - - label: - "Step 1: DUT is commissioned to TH and DUT is instructed to advertise - its operational service (_matter._tcp)." - verification: | - 1. Provision the DUT by TH (Chip-tool) - disabled: true - - - label: "Step 2: Scan for DNS-SD advertising" - PICS: - MCORE.COM.WIFI && MCORE.COM.ETH && MCORE.COM.THR && MCORE.ICD && - MCORE.SC.SII_OP_DISCOVERY_KEY && MCORE.SC.SAI_OP_DISCOVERY_KEY && - MCORE.SC.SAT_OP_DISCOVERY_KEY && MCORE.SC.T_KEY - verification: | - avahi-browse -rt _matter._tcp - - Verify the device is advertising _matterc._udp service like below output in TH terminal Log: (Verify for the DUT's actual values (like vendor ID, PID ..etc) as mentioned in the expected outcome of the test plan, The below log contains the data from the reference raspi accessory) - - + veth721e1d9 IPv6 433B62F8F07F4327-0000000000000001 _matter._tcp local - = veth721e1d9 IPv6 433B62F8F07F4327-0000000000000001 _matter._tcp local - hostname = [E45F0149AE290000.local] - address = [fe80::28e0:95ff:fed9:3085] - port = [5540] - txt = ["T=1" "SAI=300" "SII=5000"] - disabled: true diff --git a/src/app/tests/suites/manualTests.json b/src/app/tests/suites/manualTests.json index 0e6e045b8ea7f6..44115f102790e4 100644 --- a/src/app/tests/suites/manualTests.json +++ b/src/app/tests/suites/manualTests.json @@ -225,7 +225,6 @@ "Test_TC_SC_3_4", "Test_TC_SC_4_1", "Test_TC_SC_4_2", - "Test_TC_SC_4_3", "Test_TC_SC_4_4", "Test_TC_SC_4_5", "Test_TC_SC_4_6", diff --git a/src/python_testing/TC_SC_4_3.py b/src/python_testing/TC_SC_4_3.py new file mode 100644 index 00000000000000..e0ce02d2a58398 --- /dev/null +++ b/src/python_testing/TC_SC_4_3.py @@ -0,0 +1,407 @@ +# +# Copyright (c) 2024 Project CHIP Authors +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# See https://github.com/project-chip/connectedhomeip/blob/master/docs/testing/python.md#defining-the-ci-test-arguments +# for details about the block below. + +# === BEGIN CI TEST ARGUMENTS === +# test-runner-runs: +# run1: +# app: ${LIT_ICD_APP} +# factory-reset: true +# quiet: true +# app-args: --discriminator 1234 --KVS kvs1 --trace-to json:${TRACE_APP}.json +# script-args: > +# --storage-path admin_storage.json +# --commissioning-method on-network +# --discriminator 1234 +# --passcode 20202021 +# --trace-to json:${TRACE_TEST_JSON}.json +# --trace-to perfetto:${TRACE_TEST_PERFETTO}.perfetto +# === END CI TEST ARGUMENTS === + +import ipaddress +import logging +import re + +import chip.clusters as Clusters +from chip.testing.matter_testing import MatterBaseTest, TestStep, async_test_body, default_matter_test_main +from mdns_discovery.mdns_discovery import DNSRecordType, MdnsDiscovery, MdnsServiceType +from mobly import asserts +from zeroconf.const import _TYPE_AAAA, _TYPES + +''' +Purpose +The purpose of this test case is to verify that a Matter node is discoverable +and can advertise its services in a Matter network. + +Test Plan +https://github.com/CHIP-Specifications/chip-test-plans/blob/master/src/securechannel.adoc#343-tc-sc-43-discovery-dut_commissionee +''' + + +class TC_SC_4_3(MatterBaseTest): + + def steps_TC_SC_4_3(self): + return [TestStep(1, "DUT is commissioned on the same fabric as TH."), + TestStep(2, "TH reads ServerList attribute from the Descriptor cluster on EP0. ", + "If the ICD Management cluster ID (70,0x46) is present in the list, set supports_icd to true, otherwise set supports_icd to false."), + TestStep(3, + "If supports_icd is true, TH reads ActiveModeThreshold from the ICD Management cluster on EP0 and saves as active_mode_threshold.", ""), + TestStep(4, "If supports_icd is true, TH reads FeatureMap from the ICD Management cluster on EP0. If the LITS feature is set, set supports_lit to true. Otherwise set supports_lit to false.", ""), + TestStep(5, "TH constructs the instance name for the DUT as the 64-bit compressed Fabric identifier, and the assigned 64-bit Node identifier, each expressed as a fixed-length sixteen-character hexadecimal string, encoded as ASCII (UTF-8) text using capital letters, separated by a hyphen.", ""), + TestStep(6, "TH performs a query for the SRV record against the qname instance_qname.", + "Verify SRV record is returned"), + TestStep(7, "TH performs a query for the TXT record against the qname instance_qname.", + "Verify TXT record is returned"), + TestStep(8, "TH performs a query for the AAAA record against the target listed in the SRV record", + "Verify AAAA record is returned"), + TestStep(9, "TH verifies the following from the returned records:", + "TH verifies the following from the returned records: The hostname must be a fixed-length twelve-character (or sixteen-character) hexadecimal string, encoded as ASCII (UTF-8) text using capital letters.. ICD TXT key: • If supports_lit is false, verify that the ICD key is NOT present in the TXT record • If supports_lit is true, verify the ICD key IS present in the TXT record, and it has the value of 0 or 1 (ASCII) SII TXT key: • If supports_icd is true and supports_lit is false, set sit_mode to true • If supports_icd is true and supports_lit is true, set sit_mode to true if ICD=0 otherwise set sit_mode to false • If supports_icd is false, set sit_mode to false • If sit_mode is true, verify that the SII key IS present in the TXT record • if the SII key is present, verify it is a decimal value with no leading zeros and is less than or equal to 3600000 (1h in ms) SAI TXT key: • if supports_icd is true, verify that the SAI key is present in the TXT record • If the SAI key is present, verify it is a decimal value with no leading zeros and is less than or equal to 3600000 (1h in ms)"), + TestStep(10, "TH performs a DNS-SD browse for _I._sub._matter._tcp.local, where is the 64-bit compressed Fabric identifier, expressed as a fixed-length, sixteencharacter hexadecimal string, encoded as ASCII (UTF-8) text using capital letters.", + "Verify DUT returns a PTR record with DNS-SD instance name set to instance_name"), + TestStep(11, "TH performs a DNS-SD browse for _matter._tcp.local", + "Verify DUT returns a PTR record with DNS-SD instance name set to instance_name"), + ] + + ONE_HOUR_IN_MS = 3600000 + MAX_SAT_VALUE = 65535 + MAX_T_VALUE = 6 + + async def get_descriptor_server_list(self): + return await self.read_single_attribute_check_success( + endpoint=0, + dev_ctrl=self.default_controller, + cluster=Clusters.Descriptor, + attribute=Clusters.Descriptor.Attributes.ServerList + ) + + async def get_idle_mode_threshhold_ms(self): + return await self.read_single_attribute_check_success( + endpoint=0, + dev_ctrl=self.default_controller, + cluster=Clusters.IcdManagement, + attribute=Clusters.IcdManagement.Attributes.ActiveModeThreshold + ) + + async def get_icd_feature_map(self): + return await self.read_single_attribute_check_success( + endpoint=0, + dev_ctrl=self.default_controller, + cluster=Clusters.IcdManagement, + attribute=Clusters.IcdManagement.Attributes.FeatureMap + ) + + def get_dut_instance_name(self) -> str: + node_id = self.dut_node_id + compressed_fabric_id = self.default_controller.GetCompressedFabricId() + instance_name = f'{compressed_fabric_id:016X}-{node_id:016X}' + return instance_name + + def get_operational_subtype(self) -> str: + compressed_fabric_id = self.default_controller.GetCompressedFabricId() + service_name = f'_I{compressed_fabric_id:016X}._sub.{MdnsServiceType.OPERATIONAL.value}' + return service_name + + @staticmethod + def verify_decimal_value(input_value, max_value: int): + try: + input_float = float(input_value) + input_int = int(input_float) + + if str(input_value).startswith("0") and input_int != 0: + return (False, f"Input ({input_value}) has leading zeros.") + + if input_float != input_int: + return (False, f"Input ({input_value}) is not an integer.") + + if input_int <= max_value: + return (True, f"Input ({input_value}) is valid.") + else: + return (False, f"Input ({input_value}) exceeds the allowed value {max_value}.") + except ValueError: + return (False, f"Input ({input_value}) is not a valid decimal number.") + + def verify_t_value(self, t_value): + # Verify t_value is a decimal number without leading zeros and less than or equal to 6 + try: + T_int = int(t_value) + if T_int < 0 or T_int > self.MAX_T_VALUE: + return False, f"T value ({t_value}) is not in the range 0 to 6. ({t_value})" + if str(t_value).startswith("0") and T_int != 0: + return False, f"T value ({t_value}) has leading zeros." + if T_int != float(t_value): + return False, f"T value ({t_value}) is not an integer." + + # Convert to bitmap and verify bit 0 is clear + if T_int & 1 == 0: + return True, f"T value ({t_value}) is valid and bit 0 is clear." + else: + return False, f"Bit 0 is not clear. T value ({t_value})" + except ValueError: + return False, f"T value ({t_value}) is not a valid integer" + + @staticmethod + def is_ipv6_address(addresses): + if isinstance(addresses, str): + addresses = [addresses] + + for address in addresses: + try: + # Attempt to create an IPv6 address object. If successful, this is an IPv6 address. + ipaddress.IPv6Address(address) + return True, "At least one IPv6 address is present." + except ipaddress.AddressValueError: + # If an AddressValueError is raised, the current address is not a valid IPv6 address. + return False, f"Invalid IPv6 address encountered: {address}, provided addresses: {addresses}" + return False, "No IPv6 addresses found." + + @staticmethod + def extract_ipv6_address(text): + items = text.split(',') + return items[-1] + + @staticmethod + def verify_hostname(hostname: str) -> bool: + # Remove any trailing dot + if hostname.endswith('.'): + hostname = hostname[:-1] + + # Remove '.local' suffix if present + if hostname.endswith('.local'): + hostname = hostname[:-6] + + # Regular expression to match an uppercase hexadecimal string of 12 or 16 characters + pattern = re.compile(r'^[0-9A-F]{12}$|^[0-9A-F]{16}$') + return bool(pattern.match(hostname)) + + @async_test_body + async def test_TC_SC_4_3(self): + supports_icd = None + supports_lit = None + active_mode_threshold_ms = None + instance_name = None + + # *** STEP 1 *** + # DUT is commissioned on the same fabric as TH. + self.step(1) + + # *** STEP 2 *** + # TH reads from the DUT the ServerList attribute from the Descriptor cluster on EP0. If the ICD Management + # cluster ID (70,0x46) is present in the list, set supports_icd to true, otherwise set supports_icd to false. + self.step(2) + ep0_servers = await self.get_descriptor_server_list() + + # Check if ep0_servers contain the ICD Management cluster ID (0x0046) + supports_icd = Clusters.IcdManagement.id in ep0_servers + logging.info(f"supports_icd: {supports_icd}") + + # *** STEP 3 *** + # If supports_icd is true, TH reads ActiveModeThreshold from the ICD Management cluster on EP0 and saves + # as active_mode_threshold. + self.step(3) + if supports_icd: + active_mode_threshold_ms = await self.get_idle_mode_threshhold_ms() + logging.info(f"active_mode_threshold_ms: {active_mode_threshold_ms}") + + # *** STEP 4 *** + # If supports_icd is true, TH reads FeatureMap from the ICD Management cluster on EP0. If the LITS feature + # is set, set supports_lit to true. Otherwise set supports_lit to false. + self.step(4) + if supports_icd: + feature_map = await self.get_icd_feature_map() + LITS = Clusters.IcdManagement.Bitmaps.Feature.kLongIdleTimeSupport + supports_lit = bool(feature_map & LITS == LITS) + logging.info(f"kLongIdleTimeSupport set: {supports_lit}") + + # *** STEP 5 *** + # TH constructs the instance name for the DUT as the 64-bit compressed Fabric identifier, and the + # assigned 64-bit Node identifier, each expressed as a fixed-length sixteen-character hexadecimal + # string, encoded as ASCII (UTF-8) text using capital letters, separated by a hyphen. + self.step(5) + instance_name = self.get_dut_instance_name() + instance_qname = f"{instance_name}.{MdnsServiceType.OPERATIONAL.value}" + + # *** STEP 6 *** + # TH performs a query for the SRV record against the qname instance_qname. + # Verify SRV record is returned + self.step(6) + mdns = MdnsDiscovery() + operational_record = await mdns.get_service_by_record_type( + service_name=instance_qname, + service_type=MdnsServiceType.OPERATIONAL.value, + record_type=DNSRecordType.SRV, + log_output=True + ) + + # Will be used in Step 8 and 11 + # This is the hostname + hostname = operational_record.server + + # Verify SRV record is returned + srv_record_returned = operational_record is not None and operational_record.service_name == instance_qname + asserts.assert_true(srv_record_returned, "SRV record was not returned") + + # *** STEP 7 *** + # TH performs a query for the TXT record against the qname instance_qname. + # Verify TXT record is returned + self.step(7) + operational_record = await mdns.get_service_by_record_type( + service_name=instance_qname, + service_type=MdnsServiceType.OPERATIONAL.value, + record_type=DNSRecordType.TXT, + log_output=True + ) + + # Verify TXT record is returned and it contains values + txt_record_returned = operational_record.txt_record is not None and bool(operational_record.txt_record) + asserts.assert_true(txt_record_returned, "TXT record not returned or contains no values") + + # *** STEP 8 *** + # TH performs a query for the AAAA record against the target listed in the SRV record. + # Verify AAAA record is returned + self.step(8) + + quada_record = await mdns.get_service_by_record_type( + service_name=hostname, + service_type=MdnsServiceType.OPERATIONAL.value, + record_type=DNSRecordType.AAAA, + log_output=True + ) + + answer_record_type = quada_record.get_type(quada_record.type) + quada_record_type = _TYPES[_TYPE_AAAA] + + # Verify AAAA record is returned + asserts.assert_equal(hostname, quada_record.name, f"Server name mismatch: {hostname} vs {quada_record.name}") + asserts.assert_equal(quada_record_type, answer_record_type, + f"Record type should be {quada_record_type} but got {answer_record_type}") + + # # *** STEP 9 *** + # TH verifies the following from the returned records: The hostname must be a fixed-length twelve-character (or sixteen-character) + # hexadecimal string, encoded as ASCII (UTF-8) text using capital letters.. ICD TXT key: • If supports_lit is false, verify that the + # ICD key is NOT present in the TXT record • If supports_lit is true, verify the ICD key IS present in the TXT record, and it has the + # value of 0 or 1 (ASCII) SII TXT key: • If supports_icd is true and supports_lit is false, set sit_mode to true • If supports_icd is + # true and supports_lit is true, set sit_mode to true if ICD=0 otherwise set sit_mode to false • If supports_icd is false, set + # sit_mode to false • If sit_mode is true, verify that the SII key IS present in the TXT record • if the SII key is present, verify + # it is a decimal value with no leading zeros and is less than or equal to 3600000 (1h in ms) SAI TXT key: • if supports_icd is true, + # verify that the SAI key is present in the TXT record • If the SAI key is present, verify it is a decimal value with no leading + # zeros and is less than or equal to 3600000 (1h in ms) + self.step(9) + + # Verify hostname character length (12 or 16) + asserts.assert_true(self.verify_hostname(hostname=hostname), + f"Hostname for '{hostname}' is not a 12 or 16 character uppercase hexadecimal string") + + # ICD TXT KEY + if supports_lit: + logging.info("supports_lit is true, verify the ICD key IS present in the TXT record, and it has the value of 0 or 1 (ASCII).") + + # Verify the ICD key IS present + asserts.assert_in('ICD', operational_record.txt_record, "ICD key is NOT present in the TXT record.") + + # Verify it has the value of 0 or 1 (ASCII) + icd_value = int(operational_record.txt_record['ICD']) + asserts.assert_true(icd_value == 0 or icd_value == 1, "ICD value is different than 0 or 1 (ASCII).") + else: + logging.info("supports_lit is false, verify that the ICD key is NOT present in the TXT record.") + asserts.assert_not_in('ICD', operational_record.txt_record, "ICD key is present in the TXT record.") + + # SII TXT KEY + if supports_icd and not supports_lit: + sit_mode = True + + if supports_icd and supports_lit: + if icd_value == 0: + sit_mode = True + else: + sit_mode = False + + if not supports_icd: + sit_mode = False + + if sit_mode: + logging.info("sit_mode is True, verify the SII key IS present.") + asserts.assert_in('SII', operational_record.txt_record, "SII key is NOT present in the TXT record.") + + logging.info("Verify SII value is a decimal with no leading zeros and is less than or equal to 3600000 (1h in ms).") + sii_value = operational_record.txt_record['SII'] + result, message = self.verify_decimal_value(sii_value, self.ONE_HOUR_IN_MS) + asserts.assert_true(result, message) + + # SAI TXT KEY + if supports_icd: + logging.info("supports_icd is True, verify the SAI key IS present.") + asserts.assert_in('SAI', operational_record.txt_record, "SAI key is NOT present in the TXT record.") + + logging.info("Verify SAI value is a decimal with no leading zeros and is less than or equal to 3600000 (1h in ms).") + sai_value = operational_record.txt_record['SAI'] + result, message = self.verify_decimal_value(sai_value, self.ONE_HOUR_IN_MS) + asserts.assert_true(result, message) + + # SAT TXT KEY + if 'SAT' in operational_record.txt_record: + logging.info( + "SAT key is present in TXT record, verify that it is a decimal value with no leading zeros and is less than or equal to 65535.") + sat_value = operational_record.txt_record['SAT'] + result, message = self.verify_decimal_value(sat_value, self.MAX_SAT_VALUE) + asserts.assert_true(result, message) + + if supports_icd: + logging.info("supports_icd is True, verify the SAT value is equal to active_mode_threshold.") + asserts.assert_equal(int(sat_value), active_mode_threshold_ms) + + # T TXT KEY + if 'T' in operational_record.txt_record: + logging.info("T key is present in TXT record, verify if that it is a decimal value with no leading zeros and is less than or equal to 6. Convert the value to a bitmap and verify bit 0 is clear.") + t_value = operational_record.txt_record['T'] + result, message = self.verify_t_value(t_value) + asserts.assert_true(result, message) + + # AAAA + logging.info("Verify the AAAA record contains at least one IPv6 address") + ipv6_address = self.extract_ipv6_address(str(quada_record)) + result, message = self.is_ipv6_address(ipv6_address) + asserts.assert_true(result, message) + + # # *** STEP 10 *** + # TH performs a DNS-SD browse for _I._sub._matter._tcp.local, where is the 64-bit compressed + # Fabric identifier, expressed as a fixed-length, sixteencharacter hexadecimal string, encoded as ASCII (UTF-8) + # text using capital letters. Verify DUT returns a PTR record with DNS-SD instance name set to instance_name + self.step(10) + service_types = await mdns.get_service_types(log_output=True) + op_sub_type = self.get_operational_subtype() + asserts.assert_in(op_sub_type, service_types, f"No PTR record with DNS-SD instance name '{op_sub_type}' wsa found.") + + # # *** STEP 11 *** + # TH performs a DNS-SD browse for _matter._tcp.local + # Verify DUT returns a PTR record with DNS-SD instance name set to instance_name + self.step(11) + op_service_info = await mdns._get_service( + service_type=MdnsServiceType.OPERATIONAL, + log_output=True, + discovery_timeout_sec=15 + ) + + # Verify DUT returns a PTR record with DNS-SD instance name set instance_name + asserts.assert_equal(op_service_info.server, hostname, + f"No PTR record with DNS-SD instance name '{MdnsServiceType.OPERATIONAL.value}'") + asserts.assert_equal(instance_name, op_service_info.instance_name, "Instance name mismatch") + + +if __name__ == "__main__": + default_matter_test_main() diff --git a/src/python_testing/mdns_discovery/mdns_async_service_info.py b/src/python_testing/mdns_discovery/mdns_async_service_info.py new file mode 100644 index 00000000000000..61d1dc73ebcb04 --- /dev/null +++ b/src/python_testing/mdns_discovery/mdns_async_service_info.py @@ -0,0 +1,135 @@ +# +# Copyright (c) 2024 Project CHIP Authors +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import enum +from random import randint +from typing import TYPE_CHECKING, Optional + +from zeroconf import (BadTypeInNameException, DNSOutgoing, DNSQuestion, DNSQuestionType, ServiceInfo, Zeroconf, current_time_millis, + service_type_name) +from zeroconf.const import (_CLASS_IN, _DUPLICATE_QUESTION_INTERVAL, _FLAGS_QR_QUERY, _LISTENER_TIME, _MDNS_PORT, _TYPE_A, + _TYPE_AAAA, _TYPE_SRV, _TYPE_TXT) + +_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120) + + +@enum.unique +class DNSRecordType(enum.Enum): + """An MDNS record type. + + "A" - A MDNS record type + "AAAA" - AAAA MDNS record type + "SRV" - SRV MDNS record type + "TXT" - TXT MDNS record type + """ + + A = 0 + AAAA = 1 + SRV = 2 + TXT = 3 + + +class MdnsAsyncServiceInfo(ServiceInfo): + def __init__(self, name: str, type_: str) -> None: + super().__init__(type_=type_, name=name) + + if type_ and not type_.endswith(service_type_name(name, strict=False)): + raise BadTypeInNameException + + self._name = name + self.type = type_ + self.key = name.lower() + + async def async_request( + self, + zc: Zeroconf, + timeout: float, + question_type: Optional[DNSQuestionType] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + record_type: DNSRecordType = None + ) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + + This method will be run in the event loop. + + Passing addr and port is optional, and will default to the + mDNS multicast address and port. This is useful for directing + requests to a specific host that may be able to respond across + subnets. + """ + if not zc.started: + await zc.async_wait_for_start() + + now = current_time_millis() + + if TYPE_CHECKING: + assert zc.loop is not None + + first_request = True + delay = _LISTENER_TIME + next_ = now + last = now + timeout + try: + zc.async_add_listener(self, None) + while not self._is_complete: + if last <= now: + return False + if next_ <= now: + this_question_type = question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM + out: DNSOutgoing = self._generate_request_query(this_question_type, record_type) + first_request = False + + if out.questions: + zc.async_send(out, addr, port) + next_ = now + delay + next_ += randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL) + + if this_question_type is DNSQuestionType.QM and delay < _DUPLICATE_QUESTION_INTERVAL: + delay = _DUPLICATE_QUESTION_INTERVAL + + await self.async_wait(min(next_, last) - now, zc.loop) + now = current_time_millis() + finally: + zc.async_remove_listener(self) + + return True + + def _generate_request_query(self, question_type: DNSQuestionType, record_type: DNSRecordType) -> DNSOutgoing: + """Generate the request query.""" + out = DNSOutgoing(_FLAGS_QR_QUERY) + name = self._name + qu_question = question_type is DNSQuestionType.QU + + record_types = { + DNSRecordType.SRV: (_TYPE_SRV, "Requesting MDNS SRV record..."), + DNSRecordType.TXT: (_TYPE_TXT, "Requesting MDNS TXT record..."), + DNSRecordType.A: (_TYPE_A, "Requesting MDNS A record..."), + DNSRecordType.AAAA: (_TYPE_AAAA, "Requesting MDNS AAAA record..."), + } + + # Iterate over record types, add question uppon match + for r_type, (type_const, message) in record_types.items(): + if record_type is None or record_type == r_type: + print(message) + question = DNSQuestion(name, type_const, _CLASS_IN) + question.unicast = qu_question + out.add_question(question) + + return out diff --git a/src/python_testing/mdns_discovery/mdns_discovery.py b/src/python_testing/mdns_discovery/mdns_discovery.py index f8c9d46d70760a..2d7337b622a536 100644 --- a/src/python_testing/mdns_discovery/mdns_discovery.py +++ b/src/python_testing/mdns_discovery/mdns_discovery.py @@ -15,15 +15,19 @@ # limitations under the License. # - import asyncio import json import logging from dataclasses import asdict, dataclass from enum import Enum -from typing import Dict, List, Optional - -from zeroconf import IPVersion, ServiceStateChange, Zeroconf +from time import sleep +from typing import Dict, List, Optional, Union, cast + +from mdns_discovery.mdns_async_service_info import DNSRecordType, MdnsAsyncServiceInfo +from zeroconf import IPVersion, ServiceListener, ServiceStateChange, Zeroconf +from zeroconf._dns import DNSRecord +from zeroconf._engine import AsyncListener +from zeroconf._protocol.incoming import DNSIncoming from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconfServiceTypes logger = logging.getLogger(__name__) @@ -34,9 +38,6 @@ class MdnsServiceInfo: # The unique name of the mDNS service. service_name: str - # The service type of the service, typically indicating the service protocol and domain. - service_type: str - # The instance name of the service. instance_name: str @@ -67,6 +68,9 @@ class MdnsServiceInfo: # The time-to-live value for other records associated with the service. other_ttl: int + # The service type of the service, typically indicating the service protocol and domain. + service_type: Optional[str] = None + class MdnsServiceType(Enum): """ @@ -78,6 +82,25 @@ class MdnsServiceType(Enum): BORDER_ROUTER = "_meshcop._udp.local." +class MdnsServiceListener(ServiceListener): + """ + A service listener required during mDNS service discovery + """ + + def __init__(self): + self.updated_event = asyncio.Event() + + def add_service(self, zeroconf: Zeroconf, service_type: str, name: str) -> None: + sleep(0.125) + self.updated_event.set() + + def remove_service(self, zeroconf: Zeroconf, service_type: str, name: str) -> None: + pass + + def update_service(self, zeroconf: Zeroconf, service_type: str, name: str) -> None: + self.updated_event.set() + + class MdnsDiscovery: DISCOVERY_TIMEOUT_SEC = 15 @@ -118,7 +141,7 @@ async def get_commissioner_service(self, log_output: bool = False, """ Asynchronously discovers a commissioner mDNS service within the network. - Args: + Optional args: log_output (bool): Logs the discovered services to the console. Defaults to False. discovery_timeout_sec (float): Defaults to 15 seconds. @@ -134,7 +157,7 @@ async def get_commissionable_service(self, log_output: bool = False, """ Asynchronously discovers a commissionable mDNS service within the network. - Args: + Optional args: log_output (bool): Logs the discovered services to the console. Defaults to False. discovery_timeout_sec (float): Defaults to 15 seconds. @@ -153,16 +176,19 @@ async def get_operational_service(self, """ Asynchronously discovers an operational mDNS service within the network. - Args: + Optional args: log_output (bool): Logs the discovered services to the console. Defaults to False. discovery_timeout_sec (float): Defaults to 15 seconds. - node_id: the node id to create the service name from - compressed_fabric_id: the fabric id to create the service name from + node_id: the node id to create the service name from + compressed_fabric_id: the fabric id to create the service name from Returns: Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo or None if timeout reached. """ - # Validation to ensure both or none of the parameters are provided + + # Validation to ensure that both node_id and compressed_fabric_id are provided or both are None. + if (node_id is None) != (compressed_fabric_id is None): + raise ValueError("Both node_id and compressed_fabric_id must be provided together or not at all.") self._name_filter = f'{compressed_fabric_id:016x}-{node_id:016x}.{MdnsServiceType.OPERATIONAL.value}'.upper() return await self._get_service(MdnsServiceType.OPERATIONAL, log_output, discovery_timeout_sec) @@ -173,7 +199,7 @@ async def get_border_router_service(self, log_output: bool = False, """ Asynchronously discovers a border router mDNS service within the network. - Args: + Optional args: log_output (bool): Logs the discovered services to the console. Defaults to False. discovery_timeout_sec (float): Defaults to 15 seconds. @@ -188,7 +214,7 @@ async def get_all_services(self, log_output: bool = False, """ Asynchronously discovers all available mDNS services within the network. - Args: + Optional args: log_output (bool): Logs the discovered services to the console. Defaults to False. discovery_timeout_sec (float): Defaults to 15 seconds. @@ -200,6 +226,114 @@ async def get_all_services(self, log_output: bool = False, return self._discovered_services + async def get_service_types(self, log_output: bool = False, discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC,) -> List[str]: + """ + Asynchronously discovers all available mDNS services within the network and returns a list + of the service types discovered. This method utilizes the AsyncZeroconfServiceTypes.async_find() + function to perform the network scan for mDNS services. + + Optional args: + log_output (bool): If set to True, the discovered service types are logged to the console. + This can be useful for debugging or informational purposes. Defaults to False. + discovery_timeout_sec (float): The maximum time (in seconds) to wait for the discovery process. Defaults to 10.0 seconds. + + Returns: + List[str]: A list containing the service types (str) of the discovered mDNS services. Each + element in the list is a string representing a unique type of service found during + the discovery process. + """ + try: + discovered_services = list(await asyncio.wait_for(AsyncZeroconfServiceTypes.async_find(), timeout=discovery_timeout_sec)) + except asyncio.TimeoutError: + logger.info(f"MDNS service types discovery timed out after {discovery_timeout_sec} seconds.") + discovered_services = [] + + if log_output: + logger.info(f"MDNS discovered service types: {discovered_services}") + + return discovered_services + + async def get_service_by_record_type(self, service_name: str, + record_type: DNSRecordType, + service_type: str = None, + discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC, + log_output: bool = False + ) -> Union[Optional[MdnsServiceInfo], Optional[DNSRecord]]: + """ + Asynchronously discovers an mDNS service within the network by service name, service type, + and record type. + + Required args: + service_name (str): The unique name of the mDNS service. + Example: + C82B83803DECA0B2-0000000012344321._matter._tcp.local. + + record_type (DNSRecordType): The type of record to look for (SRV, TXT, AAAA, A). + Example: + _matterd._tcp.local. (from the MdnsServiceType enum) + + Optional args: + service_type (str): The service type of the service. Defaults to None. + log_output (bool): Logs the discovered services to the console. Defaults to False. + discovery_timeout_sec (float): Defaults to 15 seconds. + + Returns: + Union[Optional[MdnsServiceInfo], Optional[DNSRecord]]: An instance of MdnsServiceInfo, + a DNSRecord object, or None. + """ + mdns_service_info = None + + if service_type: + logger.info( + f"\nLooking for MDNS service type '{service_type}', service name '{service_name}', record type '{record_type.name}'\n") + else: + logger.info( + f"\nLooking for MDNS service with service name '{service_name}', record type '{record_type.name}'\n") + + # Adds service listener + service_listener = MdnsServiceListener() + self._zc.add_service_listener(MdnsServiceType.OPERATIONAL.value, service_listener) + + # Wait for the add/update service event or timeout + try: + await asyncio.wait_for(service_listener.updated_event.wait(), discovery_timeout_sec) + except asyncio.TimeoutError: + logger.info(f"Service lookup for {service_name} timeout ({discovery_timeout_sec}) reached without an update.") + finally: + self._zc.remove_service_listener(service_listener) + + # Prepare and perform query + service_info = MdnsAsyncServiceInfo(name=service_name, type_=service_type) + is_discovered = await service_info.async_request( + self._zc, + 3000, + record_type=record_type) + + if record_type in [DNSRecordType.A, DNSRecordType.AAAA]: + # Service type not supplied so we can + # query against the target/server + for protocols in self._zc.engine.protocols: + listener = cast(AsyncListener, protocols) + if listener.data: + dns_incoming = DNSIncoming(listener.data) + if dns_incoming.data: + answers = dns_incoming.answers() + logger.info(f"\nIncoming DNSRecord: {answers}\n") + return answers.pop(0) if answers else None + else: + # Adds service to discovered services + if is_discovered: + mdns_service_info = self._to_mdns_service_info_class(service_info) + self._discovered_services = {} + self._discovered_services[service_type] = [] + if mdns_service_info is not None: + self._discovered_services[service_type].append(mdns_service_info) + + if log_output: + self._log_output() + + return mdns_service_info + # Private methods async def _discover(self, discovery_timeout_sec: float, @@ -228,7 +362,7 @@ async def _discover(self, self._event.clear() if all_services: - self._service_types = list(await AsyncZeroconfServiceTypes.async_find()) + self._service_types = list(set(await AsyncZeroconfServiceTypes.async_find())) logger.info(f"Browsing for MDNS service(s) of type: {self._service_types}") @@ -315,8 +449,9 @@ async def _query_service_info(self, zeroconf: Zeroconf, service_type: str, servi mdns_service_info = self._to_mdns_service_info_class(service_info) if service_type not in self._discovered_services: - self._discovered_services[service_type] = [mdns_service_info] - else: + self._discovered_services[service_type] = [] + + if mdns_service_info is not None: self._discovered_services[service_type].append(mdns_service_info) elif self._verbose_logging: logger.warning("Service information not found.") @@ -336,7 +471,7 @@ def _to_mdns_service_info_class(self, service_info: AsyncServiceInfo) -> MdnsSer mdns_service_info = MdnsServiceInfo( service_name=service_info.name, service_type=service_info.type, - instance_name=service_info.get_name(), + instance_name=self._get_instance_name(service_info), server=service_info.server, port=service_info.port, addresses=service_info.parsed_addresses(), @@ -350,6 +485,12 @@ def _to_mdns_service_info_class(self, service_info: AsyncServiceInfo) -> MdnsSer return mdns_service_info + def _get_instance_name(self, service_info: AsyncServiceInfo) -> str: + if service_info.type: + return service_info.name[: len(service_info.name) - len(service_info.type) - 1] + else: + return service_info.name + async def _get_service(self, service_type: MdnsServiceType, log_output: bool, discovery_timeout_sec: float @@ -380,7 +521,7 @@ async def _get_service(self, service_type: MdnsServiceType, def _log_output(self) -> str: """ - Converts the discovered services to a JSON string and log it. + Converts the discovered services to a JSON string and logs it. The method is intended to be used for debugging or informational purposes, providing a clear and comprehensive view of all services discovered during the mDNS service discovery process. diff --git a/src/python_testing/mdns_discovery/mdns_service_type_enum.py b/src/python_testing/mdns_discovery/mdns_service_type_enum.py deleted file mode 100644 index efd77ca3beef5c..00000000000000 --- a/src/python_testing/mdns_discovery/mdns_service_type_enum.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Copyright (c) 2024 Project CHIP Authors -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from enum import Enum - - -class MdnsServiceType(Enum): - COMMISSIONER = "_matterd._udp.local." - COMMISSIONABLE = "_matterc._udp.local." - OPERATIONAL = "_matter._tcp.local." - BORDER_ROUTER = "_meshcop._udp.local."