Skip to content

Commit

Permalink
Dnssd changes to browse and resolve using open thread domain along wi…
Browse files Browse the repository at this point in the history
…th the local domain
  • Loading branch information
nivi-apple committed Mar 19, 2024
1 parent 27a478f commit 9ea1ed0
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 35 deletions.
30 changes: 14 additions & 16 deletions src/platform/Darwin/DnssdContexts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ namespace {

constexpr uint8_t kDnssdKeyMaxSize = 32;
constexpr uint8_t kDnssdTxtRecordMaxEntries = 20;
constexpr char kLocalDot[] = "local.";

bool IsLocalDomain(const char * domain)
{
return strcmp(kLocalDot, domain) == 0;
}

std::string GetHostNameWithoutDomain(const char * hostnameWithDomain)
{
Expand Down Expand Up @@ -252,6 +246,7 @@ void MdnsContexts::Delete(GenericContext * context)
{
DNSServiceRefDeallocate(context->serviceRef);
}

chip::Platform::Delete(context);
}

Expand Down Expand Up @@ -388,7 +383,6 @@ void BrowseContext::OnBrowseAdd(const char * name, const char * type, const char
ChipLogProgress(Discovery, "Mdns: %s name: %s, type: %s, domain: %s, interface: %" PRIu32, __func__, StringOrNullMarker(name),
StringOrNullMarker(type), StringOrNullMarker(domain), interfaceId);

VerifyOrReturn(IsLocalDomain(domain));
auto service = GetService(name, type, protocol, interfaceId);
services.push_back(service);
}
Expand All @@ -399,7 +393,6 @@ void BrowseContext::OnBrowseRemove(const char * name, const char * type, const c
StringOrNullMarker(type), StringOrNullMarker(domain), interfaceId);

VerifyOrReturn(name != nullptr);
VerifyOrReturn(IsLocalDomain(domain));

services.erase(std::remove_if(services.begin(), services.end(),
[name, type, interfaceId](const DnssdService & service) {
Expand Down Expand Up @@ -443,7 +436,6 @@ void BrowseWithDelegateContext::OnBrowseAdd(const char * name, const char * type
ChipLogProgress(Discovery, "Mdns: %s name: %s, type: %s, domain: %s, interface: %" PRIu32, __func__, StringOrNullMarker(name),
StringOrNullMarker(type), StringOrNullMarker(domain), interfaceId);

VerifyOrReturn(IsLocalDomain(domain));

auto delegate = static_cast<DnssdBrowseDelegate *>(context);
auto service = GetService(name, type, protocol, interfaceId);
Expand All @@ -456,7 +448,6 @@ void BrowseWithDelegateContext::OnBrowseRemove(const char * name, const char * t
StringOrNullMarker(type), StringOrNullMarker(domain), interfaceId);

VerifyOrReturn(name != nullptr);
VerifyOrReturn(IsLocalDomain(domain));

auto delegate = static_cast<DnssdBrowseDelegate *>(context);
auto service = GetService(name, type, protocol, interfaceId);
Expand All @@ -473,6 +464,7 @@ ResolveContext::ResolveContext(void * cbContext, DnssdResolveCallback cb, chip::
callback = cb;
protocol = GetProtocol(cbAddressType);
instanceName = instanceNameToResolve;
domainName = std::string(kLocalDot);
consumerCounter = std::move(consumerCounterToUse);
}

Expand All @@ -485,6 +477,7 @@ ResolveContext::ResolveContext(CommissioningResolveDelegate * delegate, chip::In
callback = nullptr;
protocol = GetProtocol(cbAddressType);
instanceName = instanceNameToResolve;
domainName = std::string(kLocalDot);
consumerCounter = std::move(consumerCounterToUse);
}

Expand Down Expand Up @@ -548,7 +541,7 @@ void ResolveContext::DispatchSuccess()

for (auto & interface : interfaces)
{
if (TryReportingResultsForInterfaceIndex(interface.first))
if (TryReportingResultsForInterfaceIndex(interface.first.first))
{
break;
}
Expand All @@ -568,7 +561,8 @@ bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceInde
return false;
}

auto & interface = interfaces[interfaceIndex];
std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceIndex, this->domainName);
auto & interface = interfaces[interfaceKey];
auto & ips = interface.addresses;

// Some interface may not have any ips, just ignore them.
Expand Down Expand Up @@ -596,15 +590,17 @@ bool ResolveContext::TryReportingResultsForInterfaceIndex(uint32_t interfaceInde
return true;
}

CHIP_ERROR ResolveContext::OnNewAddress(uint32_t interfaceId, const struct sockaddr * address)
CHIP_ERROR ResolveContext::OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address)
{
// If we don't have any information about this interfaceId, just ignore the
// address, since it won't be usable anyway without things like the port.
// This can happen if "local" is set up as a search domain in the DNS setup
// on the system, because the hostnames we are looking up all end in
// ".local". In other words, we can get regular DNS results in here, not
// just DNS-SD ones.
if (interfaces.find(interfaceId) == interfaces.end())
uint32_t interfaceId = interfaceKey.first;

if (interfaces.find(interfaceKey) == interfaces.end())
{
return CHIP_NO_ERROR;
}
Expand All @@ -627,7 +623,7 @@ CHIP_ERROR ResolveContext::OnNewAddress(uint32_t interfaceId, const struct socka
return CHIP_NO_ERROR;
}

interfaces[interfaceId].addresses.push_back(ip);
interfaces[interfaceKey].addresses.push_back(ip);

return CHIP_NO_ERROR;
}
Expand Down Expand Up @@ -709,7 +705,9 @@ void ResolveContext::OnNewInterface(uint32_t interfaceId, const char * fullname,
// resolving.
interface.fullyQualifiedDomainName = hostnameWithDomain;

interfaces.insert(std::make_pair(interfaceId, std::move(interface)));
std::pair<uint32_t, std::string> interfaceKey = std::make_pair(interfaceId, GetDomainNameFromHostName(hostnameWithDomain));

interfaces.insert(std::make_pair(interfaceKey, std::move(interface)));
}

bool ResolveContext::HasInterface()
Expand Down
135 changes: 118 additions & 17 deletions src/platform/Darwin/DnssdImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,19 @@
#include <lib/support/logging/CHIPLogging.h>
#include <platform/CHIPDeviceLayer.h>

using namespace chip;
using namespace chip::Dnssd;
using namespace chip::Dnssd::Internal;

namespace {

constexpr char kLocalDot[] = "local.";
// The extra time in milliseconds that we will wait for the resolution on the open thread domain to complete.
constexpr uint16_t kOpenThreadTimeoutInMsec = 250;

static bool hasOpenThreadTimerStarted = false;

constexpr DNSServiceFlags kRegisterFlags = kDNSServiceFlagsNoAutoRename;
constexpr DNSServiceFlags kBrowseFlags = 0;
constexpr DNSServiceFlags kBrowseFlags = kDNSServiceFlagsShareConnection;
constexpr DNSServiceFlags kGetAddrInfoFlags = kDNSServiceFlagsTimeout | kDNSServiceFlagsShareConnection;
constexpr DNSServiceFlags kResolveFlags = kDNSServiceFlagsShareConnection;
constexpr DNSServiceFlags kReconfirmRecordFlags = 0;
Expand All @@ -49,11 +53,27 @@ uint32_t GetInterfaceId(chip::Inet::InterfaceId interfaceId)
return interfaceId.IsPresent() ? interfaceId.GetPlatformInterface() : kDNSServiceInterfaceIndexAny;
}

std::string GetHostNameWithDomain(const char * hostname)
std::string GetHostNameWithLocalDomain(const char * hostname)
{
return std::string(hostname) + '.' + kLocalDot;
}

bool HostNameHasDomain(const char * hostname, const char * domain)
{
size_t domainLength = strlen(domain);
size_t hostnameLength = strlen(hostname);
if (domainLength > hostnameLength)
{
return false;
}
const char * found = strstr(hostname, domain);
if (found == nullptr)
{
return false;
}
return (strncmp(found, domain, domainLength) == 0);
}

void LogOnFailure(const char * name, DNSServiceErrorType err)
{
if (kDNSServiceErr_NoError != err)
Expand Down Expand Up @@ -131,10 +151,54 @@ std::shared_ptr<uint32_t> GetCounterHolder(const char * name)
namespace chip {
namespace Dnssd {


std::string GetDomainNameFromHostName(const char * hostname)
{
if (HostNameHasDomain(hostname, kLocalDot))
{
return std::string(kLocalDot);
}
else if (HostNameHasDomain(hostname, kOpenThreadDot))
{
return std::string(kOpenThreadDot);
}
return std::string();
}

Global<MdnsContexts> MdnsContexts::sInstance;

namespace {

/**
* @brief Callback that is called when the timeout for resolving on the kOpenThreadDot domain has expired.
*
* @param[in] systemLayer The system layer.
* @param[in] callbackContext The context passed to the timer callback.
*/
void OpenThreadTimerExpiredCallback(System::Layer * systemLayer, void * callbackContext)
{
ChipLogProgress(Discovery, "Mdns: Resolve completed on the open thread domain.");
VerifyOrReturn(callbackContext != nullptr && systemLayer != nullptr, ChipLogError(Discovery, "Open thread timer callback context is null"));

auto sdCtx = reinterpret_cast<ResolveContext *>(callbackContext);
VerifyOrReturn(sdCtx != nullptr, ChipLogError(Discovery, "Resolve Context is null"));
sdCtx->Finalize();
hasOpenThreadTimerStarted = false;
}

/**
* @brief Starts a timer to wait for the resolution on the kOpenThreadDot domain to happen.
*
* @param[in] timeoutSeconds The timeout in seconds.
* @param[in] ResolveContext The resolve context.
*/
void StartOpenThreadTimer(uint16_t timeoutInMSecs, ResolveContext * ctx)
{
VerifyOrReturn(ctx != nullptr, ChipLogError(Discovery, "Can't schedule open thread timer since context is null"));
DeviceLayer::SystemLayer().StartTimer(System::Clock::Milliseconds16(timeoutInMSecs),
OpenThreadTimerExpiredCallback, reinterpret_cast<void*>(ctx));
}

static void OnRegister(DNSServiceRef sdRef, DNSServiceFlags flags, DNSServiceErrorType err, const char * name, const char * type,
const char * domain, void * context)
{
Expand Down Expand Up @@ -183,14 +247,24 @@ static void OnBrowse(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t interf

CHIP_ERROR Browse(BrowseHandler * sdCtx, uint32_t interfaceId, const char * type)
{
ChipLogProgress(Discovery, "Browsing for: %s", StringOrNullMarker(type));
DNSServiceRef sdRef;
auto err = DNSServiceBrowse(&sdRef, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx);
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

return MdnsContexts::GetInstance().Add(sdCtx, sdRef);
}
// We will browse on both the local domain and the open thread domain.
ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kLocalDot);

auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceBrowse(&sdRefLocal, kBrowseFlags, interfaceId, type, kLocalDot, OnBrowse, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

ChipLogProgress(Discovery, "Browsing for: %s on domain %s", StringOrNullMarker(type), kOpenThreadDot);

DNSServiceRef sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceBrowse(&sdRefOpenThread, kBrowseFlags, interfaceId, type, kOpenThreadDot, OnBrowse, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

return MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
}
CHIP_ERROR Browse(void * context, DnssdBrowseCallback callback, uint32_t interfaceId, const char * type,
DnssdServiceProtocol protocol, intptr_t * browseIdentifier)
{
Expand Down Expand Up @@ -219,25 +293,41 @@ static void OnGetAddrInfo(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t i
ReturnOnFailure(MdnsContexts::GetInstance().Has(sdCtx));
LogOnFailure(__func__, err);

sdCtx->domainName = GetDomainNameFromHostName(hostname);
if (kDNSServiceErr_NoError == err)
{
sdCtx->OnNewAddress(interfaceId, address);
std::pair<uint32_t, std::string> key = std::make_pair(interfaceId, sdCtx->domainName);
sdCtx->OnNewAddress(key, address);
}

if (!(flags & kDNSServiceFlagsMoreComing))
{
VerifyOrReturn(sdCtx->HasAddress(), sdCtx->Finalize(kDNSServiceErr_BadState));
sdCtx->Finalize();

if (sdCtx->domainName.compare(kOpenThreadDot) == 0)
{
ChipLogProgress(Discovery, "Mdns: Resolve completed on the open thread domain.");
sdCtx->Finalize();
}
else if (sdCtx->domainName.compare(kLocalDot) == 0)
{
ChipLogProgress(Discovery, "Mdns: Resolve completed on the local domain. Starting a timer for the open thread resolve to come back");
if (!hasOpenThreadTimerStarted)
{
// Schedule a timer to allow the resolve on OpenThread domain to complete.
StartOpenThreadTimer(kOpenThreadTimeoutInMsec, sdCtx);
hasOpenThreadTimerStarted = true;
}
}
}
}

static void GetAddrInfo(ResolveContext * sdCtx)
{
auto protocol = sdCtx->protocol;

for (auto & interface : sdCtx->interfaces)
{
auto interfaceId = interface.first;
auto interfaceId = interface.first.first;
auto hostname = interface.second.fullyQualifiedDomainName.c_str();
auto sdRefCopy = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
auto err = DNSServiceGetAddrInfo(&sdRefCopy, kGetAddrInfoFlags, interfaceId, protocol, hostname, OnGetAddrInfo, sdCtx);
Expand All @@ -263,7 +353,13 @@ static void OnResolve(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t inter
if (!(flags & kDNSServiceFlagsMoreComing))
{
VerifyOrReturn(sdCtx->HasInterface(), sdCtx->Finalize(kDNSServiceErr_BadState));
GetAddrInfo(sdCtx);

// If a resolve was not requested on this context, call GetAddrInfo and set the isResolveRequested flag to true.
if (!sdCtx->isResolveRequested)
{
GetAddrInfo(sdCtx);
sdCtx->isResolveRequested = true;
}
}
}

Expand All @@ -276,8 +372,13 @@ static CHIP_ERROR Resolve(ResolveContext * sdCtx, uint32_t interfaceId, chip::In
auto err = DNSServiceCreateConnection(&sdCtx->serviceRef);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

auto sdRefCopy = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefCopy, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx);
// Similar to browse, will try to resolve using both the local domain and the open thread domain.
auto sdRefLocal = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefLocal, kResolveFlags, interfaceId, name, type, kLocalDot, OnResolve, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

auto sdRefOpenThread = sdCtx->serviceRef; // Mandatory copy because of kDNSServiceFlagsShareConnection
err = DNSServiceResolve(&sdRefOpenThread, kResolveFlags, interfaceId, name, type, kOpenThreadDot, OnResolve, sdCtx);
VerifyOrReturnError(kDNSServiceErr_NoError == err, sdCtx->Finalize(err));

auto retval = MdnsContexts::GetInstance().Add(sdCtx, sdCtx->serviceRef);
Expand Down Expand Up @@ -339,7 +440,7 @@ CHIP_ERROR ChipDnssdPublishService(const DnssdService * service, DnssdPublishCal

auto regtype = GetFullTypeWithSubTypes(service);
auto interfaceId = GetInterfaceId(service->mInterface);
auto hostname = GetHostNameWithDomain(service->mHostName);
auto hostname = GetHostNameWithLocalDomain(service->mHostName);

return Register(context, callback, interfaceId, regtype.c_str(), service->mName, service->mPort, record, service->mAddressType,
hostname.c_str());
Expand Down Expand Up @@ -485,7 +586,7 @@ CHIP_ERROR ChipDnssdReconfirmRecord(const char * hostname, chip::Inet::IPAddress

auto interfaceId = interface.GetPlatformInterface();
auto rrclass = kDNSServiceClass_IN;
auto fullname = GetHostNameWithDomain(hostname);
auto fullname = GetHostNameWithLocalDomain(hostname);

uint16_t rrtype;
uint16_t rdlen;
Expand Down
12 changes: 10 additions & 2 deletions src/platform/Darwin/DnssdImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,15 @@
#include <string>
#include <vector>

constexpr char kLocalDot[] = "local.";

constexpr char kOpenThreadDot[] = "openthread.thread.home.arpa";

namespace chip {
namespace Dnssd {

std::string GetDomainNameFromHostName(const char * hostname);

enum class ContextType
{
Register,
Expand Down Expand Up @@ -227,9 +233,11 @@ struct InterfaceInfo
struct ResolveContext : public GenericContext
{
DnssdResolveCallback callback;
std::map<uint32_t, InterfaceInfo> interfaces;
std::map<std::pair<uint32_t, std::string>, InterfaceInfo> interfaces;
DNSServiceProtocol protocol;
std::string instanceName;
std::string domainName;
bool isResolveRequested = false;
std::shared_ptr<uint32_t> consumerCounter;
BrowseContext * const browseThatCausedResolve; // Can be null

Expand All @@ -244,7 +252,7 @@ struct ResolveContext : public GenericContext
void DispatchFailure(const char * errorStr, CHIP_ERROR err) override;
void DispatchSuccess() override;

CHIP_ERROR OnNewAddress(uint32_t interfaceId, const struct sockaddr * address);
CHIP_ERROR OnNewAddress(const std::pair<uint32_t, std::string> interfaceKey, const struct sockaddr * address);
bool HasAddress();

void OnNewInterface(uint32_t interfaceId, const char * fullname, const char * hostname, uint16_t port, uint16_t txtLen,
Expand Down

0 comments on commit 9ea1ed0

Please sign in to comment.