diff --git a/config/esp32/components/chip/Kconfig b/config/esp32/components/chip/Kconfig index 90c9da8ada72f8..ee259a3244b4ca 100644 --- a/config/esp32/components/chip/Kconfig +++ b/config/esp32/components/chip/Kconfig @@ -167,6 +167,13 @@ menu "CHIP Core" A value of 0 disables automatic closing of idle connections. + config ENABLE_ENDPOINT_QUEUE_FILTER + bool "Enable UDP Endpoint queue filter for mDNS Broadcast packets" + depends on USE_MINIMAL_MDNS + default y + help + Enable this option to start a UDP Endpoint queue filter for mDNS Broadcast packets + endmenu # "Networking Options" menu "System Options" diff --git a/src/platform/ESP32/BUILD.gn b/src/platform/ESP32/BUILD.gn index 43803adce27093..2095b7ca710dfa 100644 --- a/src/platform/ESP32/BUILD.gn +++ b/src/platform/ESP32/BUILD.gn @@ -130,6 +130,9 @@ static_library("ESP32") { "DnssdImpl.h", ] } + if (chip_mdns == "minimal") { + sources += [ "ESP32EndpointQueueFilter.h" ] + } } if (chip_enable_ethernet) { diff --git a/src/platform/ESP32/ConnectivityManagerImpl_WiFi.cpp b/src/platform/ESP32/ConnectivityManagerImpl_WiFi.cpp index 48251a27129962..c16c302c0b2b5f 100644 --- a/src/platform/ESP32/ConnectivityManagerImpl_WiFi.cpp +++ b/src/platform/ESP32/ConnectivityManagerImpl_WiFi.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -1105,6 +1106,34 @@ void ConnectivityManagerImpl::OnStationIPv6AddressAvailable(const ip_event_got_i event.InterfaceIpAddressChanged.Type = InterfaceIpChangeType::kIpV6_Assigned; PlatformMgr().PostEventOrDie(&event); +#if CONFIG_ENABLE_ENDPOINT_QUEUE_FILTER + uint8_t station_mac[6]; + if (esp_wifi_get_mac(WIFI_IF_STA, station_mac) == ESP_OK) + { + static chip::Inet::ESP32EndpointQueueFilter sEndpointQueueFilter; + char station_mac_str[12]; + for (size_t i = 0; i < 6; ++i) + { + uint8_t dig1 = (station_mac[i] & 0xF0) >> 4; + uint8_t dig2 = station_mac[i] & 0x0F; + station_mac_str[2 * i] = static_cast(dig1 > 9 ? ('A' + dig1 - 0xA) : ('0' + dig1)); + station_mac_str[2 * i + 1] = static_cast(dig2 > 9 ? ('A' + dig2 - 0xA) : ('0' + dig2)); + } + if (sEndpointQueueFilter.SetMdnsHostName(chip::CharSpan(station_mac_str)) == CHIP_NO_ERROR) + { + chip::Inet::UDPEndPointImpl::SetQueueFilter(&sEndpointQueueFilter); + } + else + { + ChipLogError(DeviceLayer, "Failed to set mDNS hostname for endpoint queue filter"); + } + } + else + { + ChipLogError(DeviceLayer, "Failed to get the MAC address of station netif"); + } +#endif // CONFIG_ENABLE_ENDPOINT_QUEUE_FILTER + esp_route_hook_init(esp_netif_get_handle_from_ifkey("WIFI_STA_DEF")); } diff --git a/src/platform/ESP32/ESP32EndpointQueueFilter.h b/src/platform/ESP32/ESP32EndpointQueueFilter.h new file mode 100644 index 00000000000000..031ea29b64d798 --- /dev/null +++ b/src/platform/ESP32/ESP32EndpointQueueFilter.h @@ -0,0 +1,129 @@ +/* + * + * Copyright (c) 2023 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace chip { +namespace Inet { + +class ESP32EndpointQueueFilter : public EndpointQueueFilter +{ +public: + CHIP_ERROR SetMdnsHostName(const chip::CharSpan & hostName) + { + ReturnErrorCodeIf(hostName.size() != sizeof(mHostNameBuffer), CHIP_ERROR_INVALID_ARGUMENT); + ReturnErrorCodeIf(!IsValidMdnsHostName(hostName), CHIP_ERROR_INVALID_ARGUMENT); + memcpy(mHostNameBuffer, hostName.data(), hostName.size()); + return CHIP_NO_ERROR; + } + + FilterOutcome FilterBeforeEnqueue(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + if (!IsMdnsBroadcastPacket(pktInfo)) + { + return FilterOutcome::kAllowPacket; + } + // Drop the mDNS packets which don't contain 'matter' or ''. + const uint8_t matterBytes[] = { 'm', 'a', 't', 't', 'e', 'r' }; + if (PayloadContains(pktPayload, ByteSpan(matterBytes)) || PayloadContainsHostNameCaseInsensitive(pktPayload)) + { + return FilterOutcome::kAllowPacket; + } + return FilterOutcome::kDropPacket; + } + + FilterOutcome FilterAfterDequeue(const void * endpoint, const IPPacketInfo & pktInfo, + const chip::System::PacketBufferHandle & pktPayload) override + { + return FilterOutcome::kAllowPacket; + } + +private: + // TODO: Add unit tests for these static functions + static bool IsMdnsBroadcastPacket(const IPPacketInfo & pktInfo) + { + if (pktInfo.DestPort == 5353) + { +#if INET_CONFIG_ENABLE_IPV4 + ip_addr_t mdnsIPv4BroadcastAddr = IPADDR4_INIT_BYTES(224, 0, 0, 251); + if (pktInfo.DestAddress == chip::Inet::IPAddress(mdnsIPv4BroadcastAddr)) + { + return true; + } +#endif + ip_addr_t mdnsIPv6BroadcastAddr = IPADDR6_INIT_HOST(0xFF020000, 0, 0, 0xFB); + if (pktInfo.DestAddress == chip::Inet::IPAddress(mdnsIPv6BroadcastAddr)) + { + return true; + } + } + return false; + } + + static bool PayloadContains(const chip::System::PacketBufferHandle & payload, const chip::ByteSpan & byteSpan) + { + if (payload->HasChainedBuffer() || payload->TotalLength() < byteSpan.size()) + { + return false; + } + for (size_t i = 0; i <= payload->TotalLength() - byteSpan.size(); ++i) + { + if (memcmp(payload->Start() + i, byteSpan.data(), byteSpan.size()) == 0) + { + return true; + } + } + return false; + } + + bool PayloadContainsHostNameCaseInsensitive(const chip::System::PacketBufferHandle & payload) + { + uint8_t hostNameLowerCase[12]; + memcpy(hostNameLowerCase, mHostNameBuffer, sizeof(mHostNameBuffer)); + for (size_t i = 0; i < sizeof(hostNameLowerCase); ++i) + { + if (hostNameLowerCase[i] <= 'F' && hostNameLowerCase[i] >= 'A') + { + hostNameLowerCase[i] = static_cast('a' + hostNameLowerCase[i] - 'A'); + } + } + return PayloadContains(payload, ByteSpan(mHostNameBuffer)) || PayloadContains(payload, ByteSpan(hostNameLowerCase)); + } + + static bool IsValidMdnsHostName(const chip::CharSpan & hostName) + { + for (size_t i = 0; i < hostName.size(); ++i) + { + char ch_data = *(hostName.data() + i); + if (!((ch_data >= '0' && ch_data <= '9') || (ch_data >= 'A' && ch_data <= 'F'))) + { + return false; + } + } + return true; + } + + uint8_t mHostNameBuffer[12] = { 0 }; +}; + +} // namespace Inet +} // namespace chip