Skip to content

Commit

Permalink
Steps 8,11
Browse files Browse the repository at this point in the history
  • Loading branch information
raul-marquez-csa committed Jul 19, 2024
1 parent b459840 commit 1bf091e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 56 deletions.
24 changes: 16 additions & 8 deletions src/python_testing/TC_SC_4_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import chip.clusters as Clusters
from matter_testing_support import MatterBaseTest, TestStep, async_test_body, default_matter_test_main
from mdns_discovery.mdns_discovery import DNSRecordType, MdnsDiscovery, MdnsServiceType
from zeroconf.const import _TYPES, _TYPE_AAAA
from mobly import asserts

'''
Expand Down Expand Up @@ -161,8 +162,6 @@ def contains_ipv6_address(addresses):

@async_test_body
async def test_TC_SC_4_3(self):
print("\n"*10)

supports_icd = None
supports_lit = None
active_mode_threshold_ms = None
Expand Down Expand Up @@ -219,11 +218,10 @@ async def test_TC_SC_4_3(self):
service_name=instance_qname,
service_type=MdnsServiceType.OPERATIONAL.value,
record_type=DNSRecordType.SRV,
log_output=True,
load_from_cache=False
log_output=True
)

# Will be used in Step 11
# Will be used in Step 8 and 11
server = operational_record.server

# Verify SRV record is returned
Expand All @@ -238,8 +236,7 @@ async def test_TC_SC_4_3(self):
service_name=instance_qname,
service_type=MdnsServiceType.OPERATIONAL.value,
record_type=DNSRecordType.TXT,
log_output=True,
load_from_cache=False
log_output=True
)

# Verify TXT record is returned and it contains values
Expand All @@ -251,7 +248,18 @@ async def test_TC_SC_4_3(self):
# Verify AAAA record is returned
self.step(8)

# PENDING
quada_record = await mdns.get_service_by_record_type(
service_name=server,
record_type=DNSRecordType.AAAA,
log_output=True
)

answer_record_type = quada_record.get_type(quada_record.type)
quada = _TYPES[_TYPE_AAAA]

# Verify AAAA record is returned
asserts.assert_equal(server, quada_record.name, f"Server name mismatch: {server} vs {quada_record.name}")
asserts.assert_equal(quada, answer_record_type, f"Record type should be {quada} but got {answer_record_type}")

# # *** STEP 9 ***
# TH verifies the following from the returned records: Hostname: • If (MCORE.COM.WIFI OR MCORE.COM.ETH) target, the hostname must be a
Expand Down
47 changes: 28 additions & 19 deletions src/python_testing/mdns_discovery/mdns_async_service_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast

from zeroconf import (BadTypeInNameException, DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord,
DNSService, DNSText, ServiceInfo, Zeroconf, current_time_millis, service_type_name)
DNSService, DNSText, RecordUpdateListener, ServiceInfo, Zeroconf, current_time_millis, service_type_name)
from zeroconf._utils.net import _encode_address
from zeroconf.const import (_CLASS_IN, _DNS_HOST_TTL, _DNS_OTHER_TTL, _DUPLICATE_QUESTION_INTERVAL, _FLAGS_QR_QUERY, _LISTENER_TIME,
_MDNS_PORT, _TYPE_A, _TYPE_AAAA, _TYPE_SRV, _TYPE_TXT)
Expand Down Expand Up @@ -57,8 +57,9 @@ class DNSRecordType(enum.Enum):
class MdnsAsyncServiceInfo(ServiceInfo):
def __init__(
self,
type_: str,
zc: 'Zeroconf',
name: str,
type_: str = None,
port: Optional[int] = None,
weight: int = 0,
priority: int = 0,
Expand All @@ -74,12 +75,15 @@ def __init__(
# Accept both none, or one, but not both.
if addresses is not None and parsed_addresses is not None:
raise TypeError("addresses and parsed_addresses cannot be provided together")
if not type_.endswith(service_type_name(name, strict=False)):

if type_ and not type_.endswith(service_type_name(name, strict=False)):
raise BadTypeInNameException

self.interface_index = interface_index
self.text = b''
self.type = type_
self._zc = zc
self._name = name
self.type = type_
self.key = name.lower()
self._ipv4_addresses: List[IPv4Address] = []
self._ipv6_addresses: List[IPv6Address] = []
Expand Down Expand Up @@ -109,13 +113,11 @@ def __init__(

async def async_request(
self,
zc: 'Zeroconf',
timeout: float,
question_type: Optional[DNSQuestionType] = None,
addr: Optional[str] = None,
port: int = _MDNS_PORT,
record_type: DNSRecordType = None,
load_from_cache: bool = True
record_type: DNSRecordType = None
) -> bool:
"""Returns true if the service could be discovered on the
network, and updates this object with details discovered.
Expand All @@ -127,38 +129,34 @@ async def async_request(
requests to a specific host that may be able to respond across
subnets.
"""
if not zc.started:
await zc.async_wait_for_start()
if not self._zc.started:
await self._zc.async_wait_for_start()

now = current_time_millis()

if load_from_cache:
if self._load_from_cache(zc, now):
return True

if TYPE_CHECKING:
assert zc.loop is not None
assert self._zc.loop is not None

first_request = True
delay = self._get_initial_delay()
next_ = now
last = now + timeout
try:
zc.async_add_listener(self, None)
self.async_add_listener(self, None)
while not self._is_complete:
if last <= now:
return False
if next_ <= now:
this_question_type = question_type or QU_QUESTION if first_request else QM_QUESTION
out = self._generate_request_query(zc, now, this_question_type, record_type)
out: DNSOutgoing = self._generate_request_query(self._zc, now, this_question_type, record_type)
first_request = False
if out.questions:
# All questions may have been suppressed
# by the question history, so nothing to send,
# but keep waiting for answers in case another
# client on the network is asking the same
# question or they have not arrived yet.
zc.async_send(out, addr, port)
self._zc.async_send(out, addr, port)
next_ = now + delay
next_ += self._get_random_delay()
if this_question_type is QM_QUESTION and delay < _DUPLICATE_QUESTION_INTERVAL:
Expand All @@ -169,13 +167,24 @@ async def async_request(
# history of the remote responder.
delay = _DUPLICATE_QUESTION_INTERVAL

await self.async_wait(min(next_, last) - now, zc.loop)
await self.async_wait(min(next_, last) - now, self._zc.loop)
now = current_time_millis()
finally:
zc.async_remove_listener(self)
self._zc.async_remove_listener(self)

return True

def async_add_listener(
self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
) -> None:
"""Adds a listener for a given question. The listener will have
its update_record method called when information is available to
answer the question(s).
This function is not threadsafe and must be called in the eventloop.
"""
self._zc.record_manager.async_add_listener(listener, question)

def _generate_request_query(
self, zc: 'Zeroconf', now: float_, question_type: DNSQuestionType, record_type: DNSRecordType
) -> DNSOutgoing:
Expand Down
80 changes: 51 additions & 29 deletions src/python_testing/mdns_discovery/mdns_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,26 @@
# limitations under the License.
#


import logging
import asyncio
import json
from dataclasses import asdict, dataclass
from enum import Enum
from time import sleep
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union, cast

from mdns_discovery.mdns_async_service_info import DNSRecordType, MdnsAsyncServiceInfo
from zeroconf import IPVersion, ServiceListener, ServiceStateChange, Zeroconf
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconfServiceTypes

from zeroconf._engine import AsyncListener
from zeroconf._protocol.incoming import DNSIncoming
from zeroconf._dns import DNSRecord

@dataclass
class MdnsServiceInfo:
# The unique name of the mDNS service.
service_name: str

# The service type of the service, typically indicating the service protocol and domain.
service_type: str

# The instance name of the service.
instance_name: str

Expand Down Expand Up @@ -66,6 +65,9 @@ class MdnsServiceInfo:
# The time-to-live value for other records associated with the service.
other_ttl: int

# The service type of the service, typically indicating the service protocol and domain.
service_type: Optional[str] = None


class MdnsServiceType(Enum):
"""
Expand Down Expand Up @@ -226,12 +228,11 @@ async def get_service_types(self, log_output: bool = False) -> List[str]:
return discovered_services

async def get_service_by_record_type(self, service_name: str,
service_type: str,
record_type: DNSRecordType,
load_from_cache: bool = True,
service_type: str = None,
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC,
log_output: bool = False
) -> Optional[MdnsServiceInfo]:
) -> Union[Optional[MdnsServiceInfo], Optional[DNSRecord]]:
"""
Asynchronously discovers an mDNS service within the network by service name, service type,
and record type.
Expand All @@ -244,12 +245,17 @@ async def get_service_by_record_type(self, service_name: str,
record_type (DNSRecordType): The type of record to look for (SRV, TXT, AAAA, A).
Returns:
Optional[MdnsServiceInfo]: An instance of MdnsServiceInfo or None if timeout reached.
Union[Optional[MdnsServiceInfo], Optional[DNSRecord]]: An instance of MdnsServiceInfo,
a DNSRecord object, or None.
"""
mdns_service_info = None

print(
f"Looking for MDNS service type '{service_type}', service name '{service_name}', record type '{record_type.name}'")
if service_type:
print(
f"\nLooking for MDNS service type '{service_type}', service name '{service_name}', record type '{record_type.name}'\n")
else:
print(
f"\nLooking for MDNS service with service name '{service_name}', record type '{record_type.name}'\n")

# Adds service listener
service_listener = MdnsServiceListener()
Expand All @@ -263,26 +269,36 @@ async def get_service_by_record_type(self, service_name: str,
finally:
self._zc.remove_service_listener(service_listener)

# Get service info
service_info = MdnsAsyncServiceInfo(service_type, service_name)
# Prepare and perform query
service_info = MdnsAsyncServiceInfo(self._zc, name=service_name, type_=service_type)
is_discovered = await service_info.async_request(
self._zc,
3000,
record_type=record_type,
load_from_cache=load_from_cache)
record_type=record_type)

if not service_type:
# Service type not supplied so we can
# query against the target/server
for protocols in self._zc.engine.protocols:
listener = cast(AsyncListener, protocols)
if listener.data:
dns_incoming = DNSIncoming(listener.data)
if dns_incoming.data:
answers = dns_incoming.answers()
print(f"\nIncoming DNSRecord: {answers}\n")
return answers.pop(0) if answers else None
else:
# Adds service to discovered services
if is_discovered:
mdns_service_info = self._to_mdns_service_info_class(service_info)
self._discovered_services = {}
self._discovered_services[service_type] = []
if mdns_service_info is not None:
self._discovered_services[service_type].append(mdns_service_info)

# Adds service to discovered services
if is_discovered:
mdns_service_info = self._to_mdns_service_info_class(service_info)
self._discovered_services = {}
self._discovered_services[service_type] = []
if mdns_service_info is not None:
self._discovered_services[service_type].append(mdns_service_info)
if log_output:
self._log_output()

if log_output:
self._log_output()

return mdns_service_info
return mdns_service_info

# Private methods
async def _discover(self,
Expand Down Expand Up @@ -404,7 +420,7 @@ def _to_mdns_service_info_class(self, service_info: AsyncServiceInfo) -> MdnsSer
mdns_service_info = MdnsServiceInfo(
service_name=service_info.name,
service_type=service_info.type,
instance_name=service_info.get_name(),
instance_name=self._get_instance_name(service_info),
server=service_info.server,
port=service_info.port,
addresses=service_info.parsed_addresses(),
Expand All @@ -418,6 +434,12 @@ def _to_mdns_service_info_class(self, service_info: AsyncServiceInfo) -> MdnsSer

return mdns_service_info

def _get_instance_name(self, service_info: AsyncServiceInfo) -> str:
if service_info.type:
return service_info.name[: len(service_info.name) - len(service_info.type) - 1]
else:
return service_info.name

async def _get_service(self, service_type: MdnsServiceType,
log_output: bool,
discovery_timeout_sec: float
Expand Down

0 comments on commit 1bf091e

Please sign in to comment.