From 6339d8b8d67b120969ae2a45f30f43ed9b12cbca Mon Sep 17 00:00:00 2001 From: Phil Brookes Date: Wed, 10 Jul 2024 10:43:17 +0200 Subject: [PATCH] copy azure external-dns provider --- internal/external-dns/azure/azure.go | 462 +++++++++++++++ .../external-dns/azure/azure_private_dns.go | 445 ++++++++++++++ .../azure/azure_privatedns_test.go | 432 ++++++++++++++ internal/external-dns/azure/azure_test.go | 558 ++++++++++++++++++ internal/external-dns/azure/common.go | 47 ++ internal/external-dns/azure/common_test.go | 87 +++ internal/external-dns/azure/config.go | 154 +++++ internal/external-dns/azure/config_test.go | 46 ++ 8 files changed, 2231 insertions(+) create mode 100644 internal/external-dns/azure/azure.go create mode 100644 internal/external-dns/azure/azure_private_dns.go create mode 100644 internal/external-dns/azure/azure_privatedns_test.go create mode 100644 internal/external-dns/azure/azure_test.go create mode 100644 internal/external-dns/azure/common.go create mode 100644 internal/external-dns/azure/common_test.go create mode 100644 internal/external-dns/azure/config.go create mode 100644 internal/external-dns/azure/config_test.go diff --git a/internal/external-dns/azure/azure.go b/internal/external-dns/azure/azure.go new file mode 100644 index 00000000..dc47d8c0 --- /dev/null +++ b/internal/external-dns/azure/azure.go @@ -0,0 +1,462 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//nolint:staticcheck // Required due to the current dependency on a deprecated version of azure-sdk-for-go +package azure + +import ( + "context" + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + + azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns" + + "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/plan" + "sigs.k8s.io/external-dns/provider" +) + +const ( + azureRecordTTL = 300 +) + +// ZonesClient is an interface of dns.ZoneClient that can be stubbed for testing. +type ZonesClient interface { + NewListByResourceGroupPager(resourceGroupName string, options *dns.ZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[dns.ZonesClientListByResourceGroupResponse] +} + +// RecordSetsClient is an interface of dns.RecordSetsClient that can be stubbed for testing. +type RecordSetsClient interface { + NewListAllByDNSZonePager(resourceGroupName string, zoneName string, options *dns.RecordSetsClientListAllByDNSZoneOptions) *azcoreruntime.Pager[dns.RecordSetsClientListAllByDNSZoneResponse] + Delete(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, options *dns.RecordSetsClientDeleteOptions) (dns.RecordSetsClientDeleteResponse, error) + CreateOrUpdate(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, parameters dns.RecordSet, options *dns.RecordSetsClientCreateOrUpdateOptions) (dns.RecordSetsClientCreateOrUpdateResponse, error) +} + +// AzureProvider implements the DNS provider for Microsoft's Azure cloud platform. +type AzureProvider struct { + provider.BaseProvider + domainFilter endpoint.DomainFilter + zoneNameFilter endpoint.DomainFilter + zoneIDFilter provider.ZoneIDFilter + dryRun bool + resourceGroup string + userAssignedIdentityClientID string + zonesClient ZonesClient + recordSetsClient RecordSetsClient +} + +// NewAzureProvider creates a new Azure provider. +// +// Returns the provider or an error if a provider could not be created. +func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, resourceGroup string, userAssignedIdentityClientID string, dryRun bool) (*AzureProvider, error) { + cfg, err := getConfig(configFile, resourceGroup, userAssignedIdentityClientID) + if err != nil { + return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) + } + cred, clientOpts, err := getCredentials(*cfg) + if err != nil { + return nil, fmt.Errorf("failed to get credentials: %w", err) + } + + zonesClient, err := dns.NewZonesClient(cfg.SubscriptionID, cred, clientOpts) + if err != nil { + return nil, err + } + recordSetsClient, err := dns.NewRecordSetsClient(cfg.SubscriptionID, cred, clientOpts) + if err != nil { + return nil, err + } + return &AzureProvider{ + domainFilter: domainFilter, + zoneNameFilter: zoneNameFilter, + zoneIDFilter: zoneIDFilter, + dryRun: dryRun, + resourceGroup: cfg.ResourceGroup, + userAssignedIdentityClientID: cfg.UserAssignedIdentityID, + zonesClient: zonesClient, + recordSetsClient: recordSetsClient, + }, nil +} + +// Records gets the current records. +// +// Returns the current records or an error if the operation failed. +func (p *AzureProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { + zones, err := p.zones(ctx) + if err != nil { + return nil, err + } + + for _, zone := range zones { + pager := p.recordSetsClient.NewListAllByDNSZonePager(p.resourceGroup, *zone.Name, &dns.RecordSetsClientListAllByDNSZoneOptions{Top: nil}) + for pager.More() { + nextResult, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + for _, recordSet := range nextResult.Value { + if recordSet.Name == nil || recordSet.Type == nil { + log.Error("Skipping invalid record set with nil name or type.") + continue + } + recordType := strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/dnszones/") + if !p.SupportedRecordType(recordType) { + continue + } + name := formatAzureDNSName(*recordSet.Name, *zone.Name) + if len(p.zoneNameFilter.Filters) > 0 && !p.domainFilter.Match(name) { + log.Debugf("Skipping return of record %s because it was filtered out by the specified --domain-filter", name) + continue + } + targets := extractAzureTargets(recordSet) + if len(targets) == 0 { + log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType) + continue + } + var ttl endpoint.TTL + if recordSet.Properties.TTL != nil { + ttl = endpoint.TTL(*recordSet.Properties.TTL) + } + ep := endpoint.NewEndpointWithTTL(name, recordType, ttl, targets...) + log.Debugf( + "Found %s record for '%s' with target '%s'.", + ep.RecordType, + ep.DNSName, + ep.Targets, + ) + endpoints = append(endpoints, ep) + } + } + } + return endpoints, nil +} + +// ApplyChanges applies the given changes. +// +// Returns nil if the operation was successful or an error if the operation failed. +func (p *AzureProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { + zones, err := p.zones(ctx) + if err != nil { + return err + } + + deleted, updated := p.mapChanges(zones, changes) + p.deleteRecords(ctx, deleted) + p.updateRecords(ctx, updated) + return nil +} + +func (p *AzureProvider) zones(ctx context.Context) ([]dns.Zone, error) { + log.Debugf("Retrieving Azure DNS zones for resource group: %s.", p.resourceGroup) + var zones []dns.Zone + pager := p.zonesClient.NewListByResourceGroupPager(p.resourceGroup, &dns.ZonesClientListByResourceGroupOptions{Top: nil}) + for pager.More() { + nextResult, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + for _, zone := range nextResult.Value { + if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) { + zones = append(zones, *zone) + } else if zone.Name != nil && len(p.zoneNameFilter.Filters) > 0 && p.zoneNameFilter.Match(*zone.Name) { + // Handle zoneNameFilter + zones = append(zones, *zone) + } + } + } + log.Debugf("Found %d Azure DNS zone(s).", len(zones)) + return zones, nil +} + +func (p *AzureProvider) SupportedRecordType(recordType string) bool { + switch recordType { + case "MX": + return true + default: + return provider.SupportedRecordType(recordType) + } +} + +type azureChangeMap map[string][]*endpoint.Endpoint + +func (p *AzureProvider) mapChanges(zones []dns.Zone, changes *plan.Changes) (azureChangeMap, azureChangeMap) { + ignored := map[string]bool{} + deleted := azureChangeMap{} + updated := azureChangeMap{} + zoneNameIDMapper := provider.ZoneIDName{} + for _, z := range zones { + if z.Name != nil { + zoneNameIDMapper.Add(*z.Name, *z.Name) + } + } + mapChange := func(changeMap azureChangeMap, change *endpoint.Endpoint) { + zone, _ := zoneNameIDMapper.FindZone(change.DNSName) + if zone == "" { + if _, ok := ignored[change.DNSName]; !ok { + ignored[change.DNSName] = true + log.Infof("Ignoring changes to '%s' because a suitable Azure DNS zone was not found.", change.DNSName) + } + return + } + // Ensure the record type is suitable + changeMap[zone] = append(changeMap[zone], change) + } + + for _, change := range changes.Delete { + mapChange(deleted, change) + } + + for _, change := range changes.Create { + mapChange(updated, change) + } + + for _, change := range changes.UpdateNew { + mapChange(updated, change) + } + return deleted, updated +} + +func (p *AzureProvider) deleteRecords(ctx context.Context, deleted azureChangeMap) { + // Delete records first + for zone, endpoints := range deleted { + for _, ep := range endpoints { + name := p.recordSetNameForZone(zone, ep) + if !p.domainFilter.Match(ep.DNSName) { + log.Debugf("Skipping deletion of record %s because it was filtered out by the specified --domain-filter", ep.DNSName) + continue + } + if p.dryRun { + log.Infof("Would delete %s record named '%s' for Azure DNS zone '%s'.", ep.RecordType, name, zone) + } else { + log.Infof("Deleting %s record named '%s' for Azure DNS zone '%s'.", ep.RecordType, name, zone) + if _, err := p.recordSetsClient.Delete(ctx, p.resourceGroup, zone, name, dns.RecordType(ep.RecordType), nil); err != nil { + log.Errorf( + "Failed to delete %s record named '%s' for Azure DNS zone '%s': %v", + ep.RecordType, + name, + zone, + err, + ) + } + } + } + } +} + +func (p *AzureProvider) updateRecords(ctx context.Context, updated azureChangeMap) { + for zone, endpoints := range updated { + for _, ep := range endpoints { + name := p.recordSetNameForZone(zone, ep) + if !p.domainFilter.Match(ep.DNSName) { + log.Debugf("Skipping update of record %s because it was filtered out by the specified --domain-filter", ep.DNSName) + continue + } + if p.dryRun { + log.Infof( + "Would update %s record named '%s' to '%s' for Azure DNS zone '%s'.", + ep.RecordType, + name, + ep.Targets, + zone, + ) + continue + } + + log.Infof( + "Updating %s record named '%s' to '%s' for Azure DNS zone '%s'.", + ep.RecordType, + name, + ep.Targets, + zone, + ) + + recordSet, err := p.newRecordSet(ep) + if err == nil { + _, err = p.recordSetsClient.CreateOrUpdate( + ctx, + p.resourceGroup, + zone, + name, + dns.RecordType(ep.RecordType), + recordSet, + nil, + ) + } + if err != nil { + log.Errorf( + "Failed to update %s record named '%s' to '%s' for DNS zone '%s': %v", + ep.RecordType, + name, + ep.Targets, + zone, + err, + ) + } + } + } +} + +func (p *AzureProvider) recordSetNameForZone(zone string, endpoint *endpoint.Endpoint) string { + // Remove the zone from the record set + name := endpoint.DNSName + name = name[:len(name)-len(zone)] + name = strings.TrimSuffix(name, ".") + + // For root, use @ + if name == "" { + return "@" + } + return name +} + +func (p *AzureProvider) newRecordSet(endpoint *endpoint.Endpoint) (dns.RecordSet, error) { + var ttl int64 = azureRecordTTL + if endpoint.RecordTTL.IsConfigured() { + ttl = int64(endpoint.RecordTTL) + } + switch dns.RecordType(endpoint.RecordType) { + case dns.RecordTypeA: + aRecords := make([]*dns.ARecord, len(endpoint.Targets)) + for i, target := range endpoint.Targets { + aRecords[i] = &dns.ARecord{ + IPv4Address: to.Ptr(target), + } + } + return dns.RecordSet{ + Properties: &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + ARecords: aRecords, + }, + }, nil + case dns.RecordTypeAAAA: + aaaaRecords := make([]*dns.AaaaRecord, len(endpoint.Targets)) + for i, target := range endpoint.Targets { + aaaaRecords[i] = &dns.AaaaRecord{ + IPv6Address: to.Ptr(target), + } + } + return dns.RecordSet{ + Properties: &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + AaaaRecords: aaaaRecords, + }, + }, nil + case dns.RecordTypeCNAME: + return dns.RecordSet{ + Properties: &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + CnameRecord: &dns.CnameRecord{ + Cname: to.Ptr(endpoint.Targets[0]), + }, + }, + }, nil + case dns.RecordTypeMX: + mxRecords := make([]*dns.MxRecord, len(endpoint.Targets)) + for i, target := range endpoint.Targets { + mxRecord, err := parseMxTarget[dns.MxRecord](target) + if err != nil { + return dns.RecordSet{}, err + } + mxRecords[i] = &mxRecord + } + return dns.RecordSet{ + Properties: &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + MxRecords: mxRecords, + }, + }, nil + case dns.RecordTypeTXT: + return dns.RecordSet{ + Properties: &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + TxtRecords: []*dns.TxtRecord{ + { + Value: []*string{ + &endpoint.Targets[0], + }, + }, + }, + }, + }, nil + } + return dns.RecordSet{}, fmt.Errorf("unsupported record type '%s'", endpoint.RecordType) +} + +// Helper function (shared with test code) +func formatAzureDNSName(recordName, zoneName string) string { + if recordName == "@" { + return zoneName + } + return fmt.Sprintf("%s.%s", recordName, zoneName) +} + +// Helper function (shared with text code) +func extractAzureTargets(recordSet *dns.RecordSet) []string { + properties := recordSet.Properties + if properties == nil { + return []string{} + } + + // Check for A records + aRecords := properties.ARecords + if len(aRecords) > 0 && (aRecords)[0].IPv4Address != nil { + targets := make([]string, len(aRecords)) + for i, aRecord := range aRecords { + targets[i] = *aRecord.IPv4Address + } + return targets + } + + // Check for AAAA records + aaaaRecords := properties.AaaaRecords + if len(aaaaRecords) > 0 && (aaaaRecords)[0].IPv6Address != nil { + targets := make([]string, len(aaaaRecords)) + for i, aaaaRecord := range aaaaRecords { + targets[i] = *aaaaRecord.IPv6Address + } + return targets + } + + // Check for CNAME records + cnameRecord := properties.CnameRecord + if cnameRecord != nil && cnameRecord.Cname != nil { + return []string{*cnameRecord.Cname} + } + + // Check for MX records + mxRecords := properties.MxRecords + if len(mxRecords) > 0 && (mxRecords)[0].Exchange != nil { + targets := make([]string, len(mxRecords)) + for i, mxRecord := range mxRecords { + targets[i] = fmt.Sprintf("%d %s", *mxRecord.Preference, *mxRecord.Exchange) + } + return targets + } + + // Check for TXT records + txtRecords := properties.TxtRecords + if len(txtRecords) > 0 && (txtRecords)[0].Value != nil { + values := (txtRecords)[0].Value + if len(values) > 0 { + return []string{*(values)[0]} + } + } + return []string{} +} diff --git a/internal/external-dns/azure/azure_private_dns.go b/internal/external-dns/azure/azure_private_dns.go new file mode 100644 index 00000000..50df066f --- /dev/null +++ b/internal/external-dns/azure/azure_private_dns.go @@ -0,0 +1,445 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//nolint:staticcheck // Required due to the current dependency on a deprecated version of azure-sdk-for-go +package azure + +import ( + "context" + "fmt" + "strings" + + azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + privatedns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" + log "github.com/sirupsen/logrus" + + "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/plan" + "sigs.k8s.io/external-dns/provider" +) + +// PrivateZonesClient is an interface of privatedns.PrivateZoneClient that can be stubbed for testing. +type PrivateZonesClient interface { + NewListByResourceGroupPager(resourceGroupName string, options *privatedns.PrivateZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[privatedns.PrivateZonesClientListByResourceGroupResponse] +} + +// PrivateRecordSetsClient is an interface of privatedns.RecordSetsClient that can be stubbed for testing. +type PrivateRecordSetsClient interface { + NewListPager(resourceGroupName string, privateZoneName string, options *privatedns.RecordSetsClientListOptions) *azcoreruntime.Pager[privatedns.RecordSetsClientListResponse] + Delete(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, options *privatedns.RecordSetsClientDeleteOptions) (privatedns.RecordSetsClientDeleteResponse, error) + CreateOrUpdate(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, parameters privatedns.RecordSet, options *privatedns.RecordSetsClientCreateOrUpdateOptions) (privatedns.RecordSetsClientCreateOrUpdateResponse, error) +} + +// AzurePrivateDNSProvider implements the DNS provider for Microsoft's Azure Private DNS service +type AzurePrivateDNSProvider struct { + provider.BaseProvider + domainFilter endpoint.DomainFilter + zoneIDFilter provider.ZoneIDFilter + dryRun bool + resourceGroup string + userAssignedIdentityClientID string + zonesClient PrivateZonesClient + recordSetsClient PrivateRecordSetsClient +} + +// NewAzurePrivateDNSProvider creates a new Azure Private DNS provider. +// +// Returns the provider or an error if a provider could not be created. +func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, resourceGroup, userAssignedIdentityClientID string, dryRun bool) (*AzurePrivateDNSProvider, error) { + cfg, err := getConfig(configFile, resourceGroup, userAssignedIdentityClientID) + if err != nil { + return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) + } + cred, clientOpts, err := getCredentials(*cfg) + if err != nil { + return nil, fmt.Errorf("failed to get credentials: %w", err) + } + + zonesClient, err := privatedns.NewPrivateZonesClient(cfg.SubscriptionID, cred, clientOpts) + if err != nil { + return nil, err + } + recordSetsClient, err := privatedns.NewRecordSetsClient(cfg.SubscriptionID, cred, clientOpts) + if err != nil { + return nil, err + } + return &AzurePrivateDNSProvider{ + domainFilter: domainFilter, + zoneIDFilter: zoneIDFilter, + dryRun: dryRun, + resourceGroup: cfg.ResourceGroup, + userAssignedIdentityClientID: cfg.UserAssignedIdentityID, + zonesClient: zonesClient, + recordSetsClient: recordSetsClient, + }, nil +} + +// Records gets the current records. +// +// Returns the current records or an error if the operation failed. +func (p *AzurePrivateDNSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) { + zones, err := p.zones(ctx) + if err != nil { + return nil, err + } + + log.Debugf("Retrieving Azure Private DNS Records for resource group '%s'", p.resourceGroup) + + for _, zone := range zones { + pager := p.recordSetsClient.NewListPager(p.resourceGroup, *zone.Name, &privatedns.RecordSetsClientListOptions{Top: nil}) + for pager.More() { + nextResult, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + + for _, recordSet := range nextResult.Value { + var recordType string + if recordSet.Type == nil { + log.Debugf("Skipping invalid record set with missing type.") + continue + } + recordType = strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/privateDnsZones/") + + var name string + if recordSet.Name == nil { + log.Debugf("Skipping invalid record set with missing name.") + continue + } + name = formatAzureDNSName(*recordSet.Name, *zone.Name) + + targets := extractAzurePrivateDNSTargets(recordSet) + if len(targets) == 0 { + log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType) + continue + } + + var ttl endpoint.TTL + if recordSet.Properties.TTL != nil { + ttl = endpoint.TTL(*recordSet.Properties.TTL) + } + + ep := endpoint.NewEndpointWithTTL(name, recordType, ttl, targets...) + log.Debugf( + "Found %s record for '%s' with target '%s'.", + ep.RecordType, + ep.DNSName, + ep.Targets, + ) + endpoints = append(endpoints, ep) + } + } + } + + log.Debugf("Returning %d Azure Private DNS Records for resource group '%s'", len(endpoints), p.resourceGroup) + + return endpoints, nil +} + +// ApplyChanges applies the given changes. +// +// Returns nil if the operation was successful or an error if the operation failed. +func (p *AzurePrivateDNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { + log.Debugf("Received %d changes to process", len(changes.Create)+len(changes.Delete)+len(changes.UpdateNew)+len(changes.UpdateOld)) + + zones, err := p.zones(ctx) + if err != nil { + return err + } + + deleted, updated := p.mapChanges(zones, changes) + p.deleteRecords(ctx, deleted) + p.updateRecords(ctx, updated) + return nil +} + +func (p *AzurePrivateDNSProvider) zones(ctx context.Context) ([]privatedns.PrivateZone, error) { + log.Debugf("Retrieving Azure Private DNS zones for Resource Group '%s'", p.resourceGroup) + + var zones []privatedns.PrivateZone + + pager := p.zonesClient.NewListByResourceGroupPager(p.resourceGroup, &privatedns.PrivateZonesClientListByResourceGroupOptions{Top: nil}) + for pager.More() { + nextResult, err := pager.NextPage(ctx) + if err != nil { + return nil, err + } + for _, zone := range nextResult.Value { + log.Debugf("Validating Zone: %v", *zone.Name) + + if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) { + zones = append(zones, *zone) + } + } + } + + log.Debugf("Found %d Azure Private DNS zone(s).", len(zones)) + return zones, nil +} + +type azurePrivateDNSChangeMap map[string][]*endpoint.Endpoint + +func (p *AzurePrivateDNSProvider) mapChanges(zones []privatedns.PrivateZone, changes *plan.Changes) (azurePrivateDNSChangeMap, azurePrivateDNSChangeMap) { + ignored := map[string]bool{} + deleted := azurePrivateDNSChangeMap{} + updated := azurePrivateDNSChangeMap{} + zoneNameIDMapper := provider.ZoneIDName{} + for _, z := range zones { + if z.Name != nil { + zoneNameIDMapper.Add(*z.Name, *z.Name) + } + } + mapChange := func(changeMap azurePrivateDNSChangeMap, change *endpoint.Endpoint) { + zone, _ := zoneNameIDMapper.FindZone(change.DNSName) + if zone == "" { + if _, ok := ignored[change.DNSName]; !ok { + ignored[change.DNSName] = true + log.Infof("Ignoring changes to '%s' because a suitable Azure Private DNS zone was not found.", change.DNSName) + } + return + } + // Ensure the record type is suitable + changeMap[zone] = append(changeMap[zone], change) + } + + for _, change := range changes.Delete { + mapChange(deleted, change) + } + + for _, change := range changes.Create { + mapChange(updated, change) + } + + for _, change := range changes.UpdateNew { + mapChange(updated, change) + } + return deleted, updated +} + +func (p *AzurePrivateDNSProvider) deleteRecords(ctx context.Context, deleted azurePrivateDNSChangeMap) { + log.Debugf("Records to be deleted: %d", len(deleted)) + // Delete records first + for zone, endpoints := range deleted { + for _, ep := range endpoints { + name := p.recordSetNameForZone(zone, ep) + if p.dryRun { + log.Infof("Would delete %s record named '%s' for Azure Private DNS zone '%s'.", ep.RecordType, name, zone) + } else { + log.Infof("Deleting %s record named '%s' for Azure Private DNS zone '%s'.", ep.RecordType, name, zone) + if _, err := p.recordSetsClient.Delete(ctx, p.resourceGroup, zone, privatedns.RecordType(ep.RecordType), name, nil); err != nil { + log.Errorf( + "Failed to delete %s record named '%s' for Azure Private DNS zone '%s': %v", + ep.RecordType, + name, + zone, + err, + ) + } + } + } + } +} + +func (p *AzurePrivateDNSProvider) updateRecords(ctx context.Context, updated azurePrivateDNSChangeMap) { + log.Debugf("Records to be updated: %d", len(updated)) + for zone, endpoints := range updated { + for _, ep := range endpoints { + name := p.recordSetNameForZone(zone, ep) + if p.dryRun { + log.Infof( + "Would update %s record named '%s' to '%s' for Azure Private DNS zone '%s'.", + ep.RecordType, + name, + ep.Targets, + zone, + ) + continue + } + + log.Infof( + "Updating %s record named '%s' to '%s' for Azure Private DNS zone '%s'.", + ep.RecordType, + name, + ep.Targets, + zone, + ) + + recordSet, err := p.newRecordSet(ep) + if err == nil { + _, err = p.recordSetsClient.CreateOrUpdate( + ctx, + p.resourceGroup, + zone, + privatedns.RecordType(ep.RecordType), + name, + recordSet, + nil, + ) + } + if err != nil { + log.Errorf( + "Failed to update %s record named '%s' to '%s' for Azure Private DNS zone '%s': %v", + ep.RecordType, + name, + ep.Targets, + zone, + err, + ) + } + } + } +} + +func (p *AzurePrivateDNSProvider) recordSetNameForZone(zone string, endpoint *endpoint.Endpoint) string { + // Remove the zone from the record set + name := endpoint.DNSName + name = name[:len(name)-len(zone)] + name = strings.TrimSuffix(name, ".") + + // For root, use @ + if name == "" { + return "@" + } + return name +} + +func (p *AzurePrivateDNSProvider) newRecordSet(endpoint *endpoint.Endpoint) (privatedns.RecordSet, error) { + var ttl int64 = azureRecordTTL + if endpoint.RecordTTL.IsConfigured() { + ttl = int64(endpoint.RecordTTL) + } + switch privatedns.RecordType(endpoint.RecordType) { + case privatedns.RecordTypeA: + aRecords := make([]*privatedns.ARecord, len(endpoint.Targets)) + for i, target := range endpoint.Targets { + aRecords[i] = &privatedns.ARecord{ + IPv4Address: to.Ptr(target), + } + } + return privatedns.RecordSet{ + Properties: &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + ARecords: aRecords, + }, + }, nil + case privatedns.RecordTypeAAAA: + aaaaRecords := make([]*privatedns.AaaaRecord, len(endpoint.Targets)) + for i, target := range endpoint.Targets { + aaaaRecords[i] = &privatedns.AaaaRecord{ + IPv6Address: to.Ptr(target), + } + } + return privatedns.RecordSet{ + Properties: &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + AaaaRecords: aaaaRecords, + }, + }, nil + case privatedns.RecordTypeCNAME: + return privatedns.RecordSet{ + Properties: &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + CnameRecord: &privatedns.CnameRecord{ + Cname: to.Ptr(endpoint.Targets[0]), + }, + }, + }, nil + case privatedns.RecordTypeMX: + mxRecords := make([]*privatedns.MxRecord, len(endpoint.Targets)) + for i, target := range endpoint.Targets { + mxRecord, err := parseMxTarget[privatedns.MxRecord](target) + if err != nil { + return privatedns.RecordSet{}, err + } + mxRecords[i] = &mxRecord + } + return privatedns.RecordSet{ + Properties: &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + MxRecords: mxRecords, + }, + }, nil + case privatedns.RecordTypeTXT: + return privatedns.RecordSet{ + Properties: &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + TxtRecords: []*privatedns.TxtRecord{ + { + Value: []*string{ + &endpoint.Targets[0], + }, + }, + }, + }, + }, nil + } + return privatedns.RecordSet{}, fmt.Errorf("unsupported record type '%s'", endpoint.RecordType) +} + +// Helper function (shared with test code) +func extractAzurePrivateDNSTargets(recordSet *privatedns.RecordSet) []string { + properties := recordSet.Properties + if properties == nil { + return []string{} + } + + // Check for A records + aRecords := properties.ARecords + if len(aRecords) > 0 && (aRecords)[0].IPv4Address != nil { + targets := make([]string, len(aRecords)) + for i, aRecord := range aRecords { + targets[i] = *aRecord.IPv4Address + } + return targets + } + + // Check for AAAA records + aaaaRecords := properties.AaaaRecords + if len(aaaaRecords) > 0 && (aaaaRecords)[0].IPv6Address != nil { + targets := make([]string, len(aaaaRecords)) + for i, aaaaRecord := range aaaaRecords { + targets[i] = *aaaaRecord.IPv6Address + } + return targets + } + + // Check for CNAME records + cnameRecord := properties.CnameRecord + if cnameRecord != nil && cnameRecord.Cname != nil { + return []string{*cnameRecord.Cname} + } + + // Check for MX records + mxRecords := properties.MxRecords + if len(mxRecords) > 0 && (mxRecords)[0].Exchange != nil { + targets := make([]string, len(mxRecords)) + for i, mxRecord := range mxRecords { + targets[i] = fmt.Sprintf("%d %s", *mxRecord.Preference, *mxRecord.Exchange) + } + return targets + } + + // Check for TXT records + txtRecords := properties.TxtRecords + if len(txtRecords) > 0 && (txtRecords)[0].Value != nil { + values := (txtRecords)[0].Value + if len(values) > 0 { + return []string{*(values)[0]} + } + } + return []string{} +} diff --git a/internal/external-dns/azure/azure_privatedns_test.go b/internal/external-dns/azure/azure_privatedns_test.go new file mode 100644 index 00000000..567badea --- /dev/null +++ b/internal/external-dns/azure/azure_privatedns_test.go @@ -0,0 +1,432 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "context" + "testing" + + azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + privatedns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" + "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/plan" + "sigs.k8s.io/external-dns/provider" +) + +const ( + recordTTL = 300 +) + +// mockPrivateZonesClient implements the methods of the Azure Private DNS Zones Client which are used in the Azure Private DNS Provider +// and returns static results which are defined per test +type mockPrivateZonesClient struct { + pagingHandler azcoreruntime.PagingHandler[privatedns.PrivateZonesClientListByResourceGroupResponse] +} + +func newMockPrivateZonesClient(zones []*privatedns.PrivateZone) mockPrivateZonesClient { + pagingHandler := azcoreruntime.PagingHandler[privatedns.PrivateZonesClientListByResourceGroupResponse]{ + More: func(resp privatedns.PrivateZonesClientListByResourceGroupResponse) bool { + return false + }, + Fetcher: func(context.Context, *privatedns.PrivateZonesClientListByResourceGroupResponse) (privatedns.PrivateZonesClientListByResourceGroupResponse, error) { + return privatedns.PrivateZonesClientListByResourceGroupResponse{ + PrivateZoneListResult: privatedns.PrivateZoneListResult{ + Value: zones, + }, + }, nil + }, + } + return mockPrivateZonesClient{ + pagingHandler: pagingHandler, + } +} + +func (client *mockPrivateZonesClient) NewListByResourceGroupPager(resourceGroupName string, options *privatedns.PrivateZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[privatedns.PrivateZonesClientListByResourceGroupResponse] { + return azcoreruntime.NewPager(client.pagingHandler) +} + +// mockPrivateRecordSetsClient implements the methods of the Azure Private DNS RecordSet Client which are used in the Azure Private DNS Provider +// and returns static results which are defined per test +type mockPrivateRecordSetsClient struct { + pagingHandler azcoreruntime.PagingHandler[privatedns.RecordSetsClientListResponse] + deletedEndpoints []*endpoint.Endpoint + updatedEndpoints []*endpoint.Endpoint +} + +func newMockPrivateRecordSectsClient(recordSets []*privatedns.RecordSet) mockPrivateRecordSetsClient { + pagingHandler := azcoreruntime.PagingHandler[privatedns.RecordSetsClientListResponse]{ + More: func(resp privatedns.RecordSetsClientListResponse) bool { + return false + }, + Fetcher: func(context.Context, *privatedns.RecordSetsClientListResponse) (privatedns.RecordSetsClientListResponse, error) { + return privatedns.RecordSetsClientListResponse{ + RecordSetListResult: privatedns.RecordSetListResult{ + Value: recordSets, + }, + }, nil + }, + } + return mockPrivateRecordSetsClient{ + pagingHandler: pagingHandler, + } +} + +func (client *mockPrivateRecordSetsClient) NewListPager(resourceGroupName string, privateZoneName string, options *privatedns.RecordSetsClientListOptions) *azcoreruntime.Pager[privatedns.RecordSetsClientListResponse] { + return azcoreruntime.NewPager(client.pagingHandler) +} + +func (client *mockPrivateRecordSetsClient) Delete(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, options *privatedns.RecordSetsClientDeleteOptions) (privatedns.RecordSetsClientDeleteResponse, error) { + client.deletedEndpoints = append( + client.deletedEndpoints, + endpoint.NewEndpoint( + formatAzureDNSName(relativeRecordSetName, privateZoneName), + string(recordType), + "", + ), + ) + return privatedns.RecordSetsClientDeleteResponse{}, nil +} + +func (client *mockPrivateRecordSetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, parameters privatedns.RecordSet, options *privatedns.RecordSetsClientCreateOrUpdateOptions) (privatedns.RecordSetsClientCreateOrUpdateResponse, error) { + var ttl endpoint.TTL + if parameters.Properties.TTL != nil { + ttl = endpoint.TTL(*parameters.Properties.TTL) + } + client.updatedEndpoints = append( + client.updatedEndpoints, + endpoint.NewEndpointWithTTL( + formatAzureDNSName(relativeRecordSetName, privateZoneName), + string(recordType), + ttl, + extractAzurePrivateDNSTargets(¶meters)..., + ), + ) + return privatedns.RecordSetsClientCreateOrUpdateResponse{}, nil + //return parameters, nil +} + +func createMockPrivateZone(zone string, id string) *privatedns.PrivateZone { + return &privatedns.PrivateZone{ + ID: to.Ptr(id), + Name: to.Ptr(zone), + } +} + +func privateARecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { + aRecords := make([]*privatedns.ARecord, len(values)) + for i, value := range values { + aRecords[i] = &privatedns.ARecord{ + IPv4Address: to.Ptr(value), + } + } + return &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + ARecords: aRecords, + } +} + +func privateAAAARecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { + aaaaRecords := make([]*privatedns.AaaaRecord, len(values)) + for i, value := range values { + aaaaRecords[i] = &privatedns.AaaaRecord{ + IPv6Address: to.Ptr(value), + } + } + return &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + AaaaRecords: aaaaRecords, + } +} + +func privateCNameRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { + return &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + CnameRecord: &privatedns.CnameRecord{ + Cname: to.Ptr(values[0]), + }, + } +} + +func privateMXRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { + mxRecords := make([]*privatedns.MxRecord, len(values)) + for i, target := range values { + mxRecord, _ := parseMxTarget[privatedns.MxRecord](target) + mxRecords[i] = &mxRecord + } + return &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + MxRecords: mxRecords, + } +} + +func privateTxtRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { + return &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + TxtRecords: []*privatedns.TxtRecord{ + { + Value: []*string{&values[0]}, + }, + }, + } +} + +func privateOthersRecordSetPropertiesGetter(values []string, ttl int64) *privatedns.RecordSetProperties { + return &privatedns.RecordSetProperties{ + TTL: to.Ptr(ttl), + } +} + +func createPrivateMockRecordSet(name, recordType string, values ...string) *privatedns.RecordSet { + return createPrivateMockRecordSetMultiWithTTL(name, recordType, 0, values...) +} + +func createPrivateMockRecordSetWithTTL(name, recordType, value string, ttl int64) *privatedns.RecordSet { + return createPrivateMockRecordSetMultiWithTTL(name, recordType, ttl, value) +} + +func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values ...string) *privatedns.RecordSet { + var getterFunc func(values []string, ttl int64) *privatedns.RecordSetProperties + + switch recordType { + case endpoint.RecordTypeA: + getterFunc = privateARecordSetPropertiesGetter + case endpoint.RecordTypeAAAA: + getterFunc = privateAAAARecordSetPropertiesGetter + case endpoint.RecordTypeCNAME: + getterFunc = privateCNameRecordSetPropertiesGetter + case endpoint.RecordTypeMX: + getterFunc = privateMXRecordSetPropertiesGetter + case endpoint.RecordTypeTXT: + getterFunc = privateTxtRecordSetPropertiesGetter + default: + getterFunc = privateOthersRecordSetPropertiesGetter + } + return &privatedns.RecordSet{ + Name: to.Ptr(name), + Type: to.Ptr("Microsoft.Network/privateDnsZones/" + recordType), + Properties: getterFunc(values, ttl), + } +} + +// newMockedAzurePrivateDNSProvider creates an AzureProvider comprising the mocked clients for zones and recordsets +func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones []*privatedns.PrivateZone, recordSets []*privatedns.RecordSet) (*AzurePrivateDNSProvider, error) { + zonesClient := newMockPrivateZonesClient(zones) + recordSetsClient := newMockPrivateRecordSectsClient(recordSets) + return newAzurePrivateDNSProvider(domainFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient), nil +} + +func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, privateZonesClient PrivateZonesClient, privateRecordsClient PrivateRecordSetsClient) *AzurePrivateDNSProvider { + return &AzurePrivateDNSProvider{ + domainFilter: domainFilter, + zoneIDFilter: zoneIDFilter, + dryRun: dryRun, + resourceGroup: resourceGroup, + zonesClient: privateZonesClient, + recordSetsClient: privateRecordsClient, + } +} + +func TestAzurePrivateDNSRecord(t *testing.T) { + provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", + []*privatedns.PrivateZone{ + createMockPrivateZone("example.com", "/privateDnsZones/example.com"), + }, + []*privatedns.RecordSet{ + createPrivateMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), + createPrivateMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), + createPrivateMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"), + createPrivateMockRecordSet("@", endpoint.RecordTypeAAAA, "2001::123:123:123:122"), + createPrivateMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), + createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeAAAA, "2001::123:123:123:123", 3600), + createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), + createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + createPrivateMockRecordSetWithTTL("mail", endpoint.RecordTypeMX, "10 example.com", 4000), + }) + if err != nil { + t.Fatal(err) + } + + actual, err := provider.Records(context.Background()) + if err != nil { + t.Fatal(err) + } + expected := []*endpoint.Endpoint{ + endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "123.123.123.122"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeAAAA, "2001::123:123:123:122"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeAAAA, 3600, "2001::123:123:123:123"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, 4000, "10 example.com"), + } + + validateAzureEndpoints(t, actual, expected) +} + +func TestAzurePrivateDNSMultiRecord(t *testing.T) { + provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", + []*privatedns.PrivateZone{ + createMockPrivateZone("example.com", "/privateDnsZones/example.com"), + }, + []*privatedns.RecordSet{ + createPrivateMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), + createPrivateMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), + createPrivateMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122", "234.234.234.233"), + createPrivateMockRecordSet("@", endpoint.RecordTypeAAAA, "2001::123:123:123:122", "2001::234:234:234:233"), + createPrivateMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + createPrivateMockRecordSetMultiWithTTL("nginx", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"), + createPrivateMockRecordSetMultiWithTTL("nginx", endpoint.RecordTypeAAAA, 3600, "2001::123:123:123:123", "2001::234:234:234:234"), + createPrivateMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), + createPrivateMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + createPrivateMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"), + }) + if err != nil { + t.Fatal(err) + } + + actual, err := provider.Records(context.Background()) + if err != nil { + t.Fatal(err) + } + expected := []*endpoint.Endpoint{ + endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "123.123.123.122", "234.234.234.233"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeAAAA, "2001::123:123:123:122", "2001::234:234:234:233"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeAAAA, 3600, "2001::123:123:123:123", "2001::234:234:234:234"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"), + } + + validateAzureEndpoints(t, actual, expected) +} + +func TestAzurePrivateDNSApplyChanges(t *testing.T) { + recordsClient := mockPrivateRecordSetsClient{} + + testAzurePrivateDNSApplyChangesInternal(t, false, &recordsClient) + + validateAzureEndpoints(t, recordsClient.deletedEndpoints, []*endpoint.Endpoint{ + endpoint.NewEndpoint("deleted.example.com", endpoint.RecordTypeA, ""), + endpoint.NewEndpoint("deletedaaaa.example.com", endpoint.RecordTypeAAAA, ""), + endpoint.NewEndpoint("deletedcname.example.com", endpoint.RecordTypeCNAME, ""), + }) + + validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4"), + endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeAAAA, endpoint.TTL(recordTTL), "2001::1:2:3:4"), + endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "1.2.3.5"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeAAAA, endpoint.TTL(recordTTL), "2001::1:2:3:4", "2001::1:2:3:5"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("bar.example.com", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "other.com"), + endpoint.NewEndpointWithTTL("bar.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "5.6.7.8"), + endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeAAAA, endpoint.TTL(recordTTL), "2001::5:6:7:8"), + endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), + endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeAAAA, 3600, "2001::111:222:111:222"), + endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + endpoint.NewEndpointWithTTL("newmail.example.com", endpoint.RecordTypeMX, 7200, "40 bar.other.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(recordTTL), "10 other.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + }) +} + +func TestAzurePrivateDNSApplyChangesDryRun(t *testing.T) { + recordsClient := mockRecordSetsClient{} + + testAzureApplyChangesInternal(t, true, &recordsClient) + + validateAzureEndpoints(t, recordsClient.deletedEndpoints, []*endpoint.Endpoint{}) + + validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{}) +} + +func testAzurePrivateDNSApplyChangesInternal(t *testing.T, dryRun bool, client PrivateRecordSetsClient) { + zones := []*privatedns.PrivateZone{ + createMockPrivateZone("example.com", "/privateDnsZones/example.com"), + createMockPrivateZone("other.com", "/privateDnsZones/other.com"), + } + zonesClient := newMockPrivateZonesClient(zones) + + provider := newAzurePrivateDNSProvider( + endpoint.NewDomainFilter([]string{""}), + provider.NewZoneIDFilter([]string{""}), + dryRun, + "group", + &zonesClient, + client, + ) + + createRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "1.2.3.4"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeAAAA, "2001::1:2:3:4"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeA, "1.2.3.5", "1.2.3.4"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeAAAA, "2001::1:2:3:5", "2001::1:2:3:4"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeA, "5.6.7.8"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeAAAA, "2001::5:6:7:8"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeA, "4.4.4.4"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeAAAA, "2001::4:4:4:4"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("mail.example.com", endpoint.RecordTypeMX, "10 other.com"), + endpoint.NewEndpoint("mail.example.com", endpoint.RecordTypeTXT, "tag"), + } + + currentRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("old.example.com", endpoint.RecordTypeA, "121.212.121.212"), + endpoint.NewEndpoint("oldcname.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("old.nope.com", endpoint.RecordTypeA, "121.212.121.212"), + endpoint.NewEndpoint("oldmail.example.com", endpoint.RecordTypeMX, "20 foo.other.com"), + } + updatedRecords := []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), + endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeAAAA, 3600, "2001::111:222:111:222"), + endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + endpoint.NewEndpoint("new.nope.com", endpoint.RecordTypeA, "222.111.222.111"), + endpoint.NewEndpoint("new.nope.com", endpoint.RecordTypeAAAA, "2001::222:111:222:111"), + endpoint.NewEndpointWithTTL("newmail.example.com", endpoint.RecordTypeMX, 7200, "40 bar.other.com"), + } + + deleteRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("deleted.example.com", endpoint.RecordTypeA, "111.222.111.222"), + endpoint.NewEndpoint("deletedaaaa.example.com", endpoint.RecordTypeAAAA, "2001::111:222:111:222"), + endpoint.NewEndpoint("deletedcname.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("deleted.nope.com", endpoint.RecordTypeA, "222.111.222.111"), + endpoint.NewEndpoint("deleted.nope.com", endpoint.RecordTypeAAAA, "2001::222:111:222:111"), + } + + changes := &plan.Changes{ + Create: createRecords, + UpdateNew: updatedRecords, + UpdateOld: currentRecords, + Delete: deleteRecords, + } + + if err := provider.ApplyChanges(context.Background(), changes); err != nil { + t.Fatal(err) + } +} diff --git a/internal/external-dns/azure/azure_test.go b/internal/external-dns/azure/azure_test.go new file mode 100644 index 00000000..4fd9fa8f --- /dev/null +++ b/internal/external-dns/azure/azure_test.go @@ -0,0 +1,558 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "context" + "testing" + + azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns" + "github.com/stretchr/testify/assert" + + "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/internal/testutils" + "sigs.k8s.io/external-dns/plan" + "sigs.k8s.io/external-dns/provider" +) + +// mockZonesClient implements the methods of the Azure DNS Zones Client which are used in the Azure Provider +// and returns static results which are defined per test +type mockZonesClient struct { + pagingHandler azcoreruntime.PagingHandler[dns.ZonesClientListByResourceGroupResponse] +} + +func newMockZonesClient(zones []*dns.Zone) mockZonesClient { + pagingHandler := azcoreruntime.PagingHandler[dns.ZonesClientListByResourceGroupResponse]{ + More: func(resp dns.ZonesClientListByResourceGroupResponse) bool { + return false + }, + Fetcher: func(context.Context, *dns.ZonesClientListByResourceGroupResponse) (dns.ZonesClientListByResourceGroupResponse, error) { + return dns.ZonesClientListByResourceGroupResponse{ + ZoneListResult: dns.ZoneListResult{ + Value: zones, + }, + }, nil + }, + } + return mockZonesClient{ + pagingHandler: pagingHandler, + } +} + +func (client *mockZonesClient) NewListByResourceGroupPager(resourceGroupName string, options *dns.ZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[dns.ZonesClientListByResourceGroupResponse] { + return azcoreruntime.NewPager(client.pagingHandler) +} + +// mockZonesClient implements the methods of the Azure DNS RecordSet Client which are used in the Azure Provider +// and returns static results which are defined per test +type mockRecordSetsClient struct { + pagingHandler azcoreruntime.PagingHandler[dns.RecordSetsClientListAllByDNSZoneResponse] + deletedEndpoints []*endpoint.Endpoint + updatedEndpoints []*endpoint.Endpoint +} + +func newMockRecordSetsClient(recordSets []*dns.RecordSet) mockRecordSetsClient { + pagingHandler := azcoreruntime.PagingHandler[dns.RecordSetsClientListAllByDNSZoneResponse]{ + More: func(resp dns.RecordSetsClientListAllByDNSZoneResponse) bool { + return false + }, + Fetcher: func(context.Context, *dns.RecordSetsClientListAllByDNSZoneResponse) (dns.RecordSetsClientListAllByDNSZoneResponse, error) { + return dns.RecordSetsClientListAllByDNSZoneResponse{ + RecordSetListResult: dns.RecordSetListResult{ + Value: recordSets, + }, + }, nil + }, + } + return mockRecordSetsClient{ + pagingHandler: pagingHandler, + } +} + +func (client *mockRecordSetsClient) NewListAllByDNSZonePager(resourceGroupName string, zoneName string, options *dns.RecordSetsClientListAllByDNSZoneOptions) *azcoreruntime.Pager[dns.RecordSetsClientListAllByDNSZoneResponse] { + return azcoreruntime.NewPager(client.pagingHandler) +} + +func (client *mockRecordSetsClient) Delete(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, options *dns.RecordSetsClientDeleteOptions) (dns.RecordSetsClientDeleteResponse, error) { + client.deletedEndpoints = append( + client.deletedEndpoints, + endpoint.NewEndpoint( + formatAzureDNSName(relativeRecordSetName, zoneName), + string(recordType), + "", + ), + ) + return dns.RecordSetsClientDeleteResponse{}, nil +} + +func (client *mockRecordSetsClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, parameters dns.RecordSet, options *dns.RecordSetsClientCreateOrUpdateOptions) (dns.RecordSetsClientCreateOrUpdateResponse, error) { + var ttl endpoint.TTL + if parameters.Properties.TTL != nil { + ttl = endpoint.TTL(*parameters.Properties.TTL) + } + client.updatedEndpoints = append( + client.updatedEndpoints, + endpoint.NewEndpointWithTTL( + formatAzureDNSName(relativeRecordSetName, zoneName), + string(recordType), + ttl, + extractAzureTargets(¶meters)..., + ), + ) + return dns.RecordSetsClientCreateOrUpdateResponse{}, nil +} + +func createMockZone(zone string, id string) *dns.Zone { + return &dns.Zone{ + ID: to.Ptr(id), + Name: to.Ptr(zone), + } +} + +func aRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { + aRecords := make([]*dns.ARecord, len(values)) + for i, value := range values { + aRecords[i] = &dns.ARecord{ + IPv4Address: to.Ptr(value), + } + } + return &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + ARecords: aRecords, + } +} + +func aaaaRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { + aaaaRecords := make([]*dns.AaaaRecord, len(values)) + for i, value := range values { + aaaaRecords[i] = &dns.AaaaRecord{ + IPv6Address: to.Ptr(value), + } + } + return &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + AaaaRecords: aaaaRecords, + } +} + +func cNameRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { + return &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + CnameRecord: &dns.CnameRecord{ + Cname: to.Ptr(values[0]), + }, + } +} + +func mxRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { + mxRecords := make([]*dns.MxRecord, len(values)) + for i, target := range values { + mxRecord, _ := parseMxTarget[dns.MxRecord](target) + mxRecords[i] = &mxRecord + } + return &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + MxRecords: mxRecords, + } +} + +func txtRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { + return &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + TxtRecords: []*dns.TxtRecord{ + { + Value: []*string{to.Ptr(values[0])}, + }, + }, + } +} + +func othersRecordSetPropertiesGetter(values []string, ttl int64) *dns.RecordSetProperties { + return &dns.RecordSetProperties{ + TTL: to.Ptr(ttl), + } +} + +func createMockRecordSet(name, recordType string, values ...string) *dns.RecordSet { + return createMockRecordSetMultiWithTTL(name, recordType, 0, values...) +} + +func createMockRecordSetWithTTL(name, recordType, value string, ttl int64) *dns.RecordSet { + return createMockRecordSetMultiWithTTL(name, recordType, ttl, value) +} + +func createMockRecordSetMultiWithTTL(name, recordType string, ttl int64, values ...string) *dns.RecordSet { + var getterFunc func(values []string, ttl int64) *dns.RecordSetProperties + + switch recordType { + case endpoint.RecordTypeA: + getterFunc = aRecordSetPropertiesGetter + case endpoint.RecordTypeAAAA: + getterFunc = aaaaRecordSetPropertiesGetter + case endpoint.RecordTypeCNAME: + getterFunc = cNameRecordSetPropertiesGetter + case endpoint.RecordTypeMX: + getterFunc = mxRecordSetPropertiesGetter + case endpoint.RecordTypeTXT: + getterFunc = txtRecordSetPropertiesGetter + default: + getterFunc = othersRecordSetPropertiesGetter + } + return &dns.RecordSet{ + Name: to.Ptr(name), + Type: to.Ptr("Microsoft.Network/dnszones/" + recordType), + Properties: getterFunc(values, ttl), + } +} + +// newMockedAzureProvider creates an AzureProvider comprising the mocked clients for zones and recordsets +func newMockedAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zones []*dns.Zone, recordSets []*dns.RecordSet) (*AzureProvider, error) { + zonesClient := newMockZonesClient(zones) + recordSetsClient := newMockRecordSetsClient(recordSets) + return newAzureProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, userAssignedIdentityClientID, &zonesClient, &recordSetsClient), nil +} + +func newAzureProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, userAssignedIdentityClientID string, zonesClient ZonesClient, recordsClient RecordSetsClient) *AzureProvider { + return &AzureProvider{ + domainFilter: domainFilter, + zoneNameFilter: zoneNameFilter, + zoneIDFilter: zoneIDFilter, + dryRun: dryRun, + resourceGroup: resourceGroup, + userAssignedIdentityClientID: userAssignedIdentityClientID, + zonesClient: zonesClient, + recordSetsClient: recordsClient, + } +} + +func validateAzureEndpoints(t *testing.T, endpoints []*endpoint.Endpoint, expected []*endpoint.Endpoint) { + assert.True(t, testutils.SameEndpoints(endpoints, expected), "expected and actual endpoints don't match. %s:%s", endpoints, expected) +} + +func TestAzureRecord(t *testing.T) { + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + []*dns.Zone{ + createMockZone("example.com", "/dnszones/example.com"), + }, + []*dns.RecordSet{ + createMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), + createMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), + createMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"), + createMockRecordSet("@", endpoint.RecordTypeAAAA, "2001::123:123:123:122"), + createMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + createMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), + createMockRecordSetWithTTL("nginx", endpoint.RecordTypeAAAA, "2001::123:123:123:123", 3600), + createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), + createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + createMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com"), + }) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + actual, err := provider.Records(ctx) + if err != nil { + t.Fatal(err) + } + expected := []*endpoint.Endpoint{ + endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "123.123.123.122"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeAAAA, "2001::123:123:123:122"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeAAAA, 3600, "2001::123:123:123:123"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, 4000, "10 example.com"), + } + + validateAzureEndpoints(t, actual, expected) +} + +func TestAzureMultiRecord(t *testing.T) { + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + []*dns.Zone{ + createMockZone("example.com", "/dnszones/example.com"), + }, + []*dns.RecordSet{ + createMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), + createMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), + createMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122", "234.234.234.233"), + createMockRecordSet("@", endpoint.RecordTypeAAAA, "2001::123:123:123:122", "2001::234:234:234:233"), + createMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + createMockRecordSetMultiWithTTL("nginx", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"), + createMockRecordSetMultiWithTTL("nginx", endpoint.RecordTypeAAAA, 3600, "2001::123:123:123:123", "2001::234:234:234:234"), + createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), + createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + createMockRecordSetMultiWithTTL("mail", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"), + }) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + actual, err := provider.Records(ctx) + if err != nil { + t.Fatal(err) + } + expected := []*endpoint.Endpoint{ + endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "123.123.123.122", "234.234.234.233"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeAAAA, "2001::123:123:123:122", "2001::234:234:234:233"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123", "234.234.234.234"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeAAAA, 3600, "2001::123:123:123:123", "2001::234:234:234:234"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("hack.example.com", endpoint.RecordTypeCNAME, 10, "hack.azurewebsites.net"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, 4000, "10 example.com", "20 backup.example.com"), + } + + validateAzureEndpoints(t, actual, expected) +} + +func TestAzureApplyChanges(t *testing.T) { + recordsClient := mockRecordSetsClient{} + + testAzureApplyChangesInternal(t, false, &recordsClient) + + validateAzureEndpoints(t, recordsClient.deletedEndpoints, []*endpoint.Endpoint{ + endpoint.NewEndpoint("deleted.example.com", endpoint.RecordTypeA, ""), + endpoint.NewEndpoint("deletedaaaa.example.com", endpoint.RecordTypeAAAA, ""), + endpoint.NewEndpoint("deletedcname.example.com", endpoint.RecordTypeCNAME, ""), + }) + + validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4"), + endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeAAAA, endpoint.TTL(recordTTL), "2001::1:2:3:4"), + endpoint.NewEndpointWithTTL("example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "1.2.3.5"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeAAAA, endpoint.TTL(recordTTL), "2001::1:2:3:4", "2001::1:2:3:5"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("bar.example.com", endpoint.RecordTypeCNAME, endpoint.TTL(recordTTL), "other.com"), + endpoint.NewEndpointWithTTL("bar.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "5.6.7.8"), + endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeAAAA, endpoint.TTL(recordTTL), "2001::5:6:7:8"), + endpoint.NewEndpointWithTTL("other.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), + endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeAAAA, 3600, "2001::111:222:111:222"), + endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + endpoint.NewEndpointWithTTL("newmail.example.com", endpoint.RecordTypeMX, 7200, "40 bar.other.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeMX, endpoint.TTL(recordTTL), "10 other.com"), + endpoint.NewEndpointWithTTL("mail.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + }) +} + +func TestAzureApplyChangesDryRun(t *testing.T) { + recordsClient := mockRecordSetsClient{} + + testAzureApplyChangesInternal(t, true, &recordsClient) + + validateAzureEndpoints(t, recordsClient.deletedEndpoints, []*endpoint.Endpoint{}) + + validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{}) +} + +func testAzureApplyChangesInternal(t *testing.T, dryRun bool, client RecordSetsClient) { + zones := []*dns.Zone{ + createMockZone("example.com", "/dnszones/example.com"), + createMockZone("other.com", "/dnszones/other.com"), + } + zonesClient := newMockZonesClient(zones) + + provider := newAzureProvider( + endpoint.NewDomainFilter([]string{""}), + endpoint.NewDomainFilter([]string{""}), + provider.NewZoneIDFilter([]string{""}), + dryRun, + "group", + "", + &zonesClient, + client, + ) + + createRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "1.2.3.4"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeAAAA, "2001::1:2:3:4"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeA, "1.2.3.5", "1.2.3.4"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeAAAA, "2001::1:2:3:5", "2001::1:2:3:4"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeA, "5.6.7.8"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeAAAA, "2001::5:6:7:8"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeA, "4.4.4.4"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeAAAA, "2001::4:4:4:4"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("mail.example.com", endpoint.RecordTypeMX, "10 other.com"), + endpoint.NewEndpoint("mail.example.com", endpoint.RecordTypeTXT, "tag"), + } + + currentRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("old.example.com", endpoint.RecordTypeA, "121.212.121.212"), + endpoint.NewEndpoint("oldcname.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("old.nope.com", endpoint.RecordTypeA, "121.212.121.212"), + endpoint.NewEndpoint("oldmail.example.com", endpoint.RecordTypeMX, "20 foo.other.com"), + } + updatedRecords := []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), + endpoint.NewEndpointWithTTL("new.example.com", endpoint.RecordTypeAAAA, 3600, "2001::111:222:111:222"), + endpoint.NewEndpointWithTTL("newcname.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + endpoint.NewEndpoint("new.nope.com", endpoint.RecordTypeA, "222.111.222.111"), + endpoint.NewEndpoint("new.nope.com", endpoint.RecordTypeAAAA, "2001::222:111:222:111"), + endpoint.NewEndpointWithTTL("newmail.example.com", endpoint.RecordTypeMX, 7200, "40 bar.other.com"), + } + + deleteRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("deleted.example.com", endpoint.RecordTypeA, "111.222.111.222"), + endpoint.NewEndpoint("deletedaaaa.example.com", endpoint.RecordTypeAAAA, "2001::111:222:111:222"), + endpoint.NewEndpoint("deletedcname.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("deleted.nope.com", endpoint.RecordTypeA, "222.111.222.111"), + endpoint.NewEndpoint("deleted.nope.com", endpoint.RecordTypeAAAA, "2001::222:111:222:111"), + } + + changes := &plan.Changes{ + Create: createRecords, + UpdateNew: updatedRecords, + UpdateOld: currentRecords, + Delete: deleteRecords, + } + + if err := provider.ApplyChanges(context.Background(), changes); err != nil { + t.Fatal(err) + } +} + +func TestAzureNameFilter(t *testing.T) { + provider, err := newMockedAzureProvider(endpoint.NewDomainFilter([]string{"nginx.example.com"}), endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s", "", + []*dns.Zone{ + createMockZone("example.com", "/dnszones/example.com"), + }, + + []*dns.RecordSet{ + createMockRecordSet("@", "NS", "ns1-03.azure-dns.com."), + createMockRecordSet("@", "SOA", "Email: azuredns-hostmaster.microsoft.com"), + createMockRecordSet("@", endpoint.RecordTypeA, "123.123.123.122"), + createMockRecordSet("@", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default"), + createMockRecordSetWithTTL("test.nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), + createMockRecordSetWithTTL("nginx", endpoint.RecordTypeA, "123.123.123.123", 3600), + createMockRecordSetWithTTL("nginx", endpoint.RecordTypeTXT, "heritage=external-dns,external-dns/owner=default", recordTTL), + createMockRecordSetWithTTL("mail.nginx", endpoint.RecordTypeMX, "20 example.com", recordTTL), + createMockRecordSetWithTTL("hack", endpoint.RecordTypeCNAME, "hack.azurewebsites.net", 10), + }) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + actual, err := provider.Records(ctx) + if err != nil { + t.Fatal(err) + } + expected := []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("test.nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeA, 3600, "123.123.123.123"), + endpoint.NewEndpointWithTTL("nginx.example.com", endpoint.RecordTypeTXT, recordTTL, "heritage=external-dns,external-dns/owner=default"), + endpoint.NewEndpointWithTTL("mail.nginx.example.com", endpoint.RecordTypeMX, recordTTL, "20 example.com"), + } + + validateAzureEndpoints(t, actual, expected) +} + +func TestAzureApplyChangesZoneName(t *testing.T) { + recordsClient := mockRecordSetsClient{} + + testAzureApplyChangesInternalZoneName(t, false, &recordsClient) + + validateAzureEndpoints(t, recordsClient.deletedEndpoints, []*endpoint.Endpoint{ + endpoint.NewEndpoint("deleted.foo.example.com", endpoint.RecordTypeA, ""), + endpoint.NewEndpoint("deletedaaaa.foo.example.com", endpoint.RecordTypeAAAA, ""), + endpoint.NewEndpoint("deletedcname.foo.example.com", endpoint.RecordTypeCNAME, ""), + }) + + validateAzureEndpoints(t, recordsClient.updatedEndpoints, []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeA, endpoint.TTL(recordTTL), "1.2.3.4", "1.2.3.5"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeAAAA, endpoint.TTL(recordTTL), "2001::1:2:3:4", "2001::1:2:3:5"), + endpoint.NewEndpointWithTTL("foo.example.com", endpoint.RecordTypeTXT, endpoint.TTL(recordTTL), "tag"), + endpoint.NewEndpointWithTTL("new.foo.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), + endpoint.NewEndpointWithTTL("new.foo.example.com", endpoint.RecordTypeAAAA, 3600, "2001::111:222:111:222"), + endpoint.NewEndpointWithTTL("newcname.foo.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + }) +} + +func testAzureApplyChangesInternalZoneName(t *testing.T, dryRun bool, client RecordSetsClient) { + zonesClient := newMockZonesClient([]*dns.Zone{createMockZone("example.com", "/dnszones/example.com")}) + + provider := newAzureProvider( + endpoint.NewDomainFilter([]string{"foo.example.com"}), + endpoint.NewDomainFilter([]string{"example.com"}), + provider.NewZoneIDFilter([]string{""}), + dryRun, + "group", + "", + &zonesClient, + client, + ) + + createRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("example.com", endpoint.RecordTypeA, "1.2.3.4"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeAAAA, "2001::1:2:3:4"), + endpoint.NewEndpoint("example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeA, "1.2.3.5", "1.2.3.4"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeAAAA, "2001::1:2:3:5", "2001::1:2:3:4"), + endpoint.NewEndpoint("foo.example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("bar.example.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeA, "5.6.7.8"), + endpoint.NewEndpoint("other.com", endpoint.RecordTypeTXT, "tag"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeA, "4.4.4.4"), + endpoint.NewEndpoint("nope.com", endpoint.RecordTypeTXT, "tag"), + } + + currentRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("old.foo.example.com", endpoint.RecordTypeA, "121.212.121.212"), + endpoint.NewEndpoint("oldcname.foo.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("old.nope.example.com", endpoint.RecordTypeA, "121.212.121.212"), + } + updatedRecords := []*endpoint.Endpoint{ + endpoint.NewEndpointWithTTL("new.foo.example.com", endpoint.RecordTypeA, 3600, "111.222.111.222"), + endpoint.NewEndpointWithTTL("new.foo.example.com", endpoint.RecordTypeAAAA, 3600, "2001::111:222:111:222"), + endpoint.NewEndpointWithTTL("newcname.foo.example.com", endpoint.RecordTypeCNAME, 10, "other.com"), + endpoint.NewEndpoint("new.nope.example.com", endpoint.RecordTypeA, "222.111.222.111"), + endpoint.NewEndpoint("new.nope.example.com", endpoint.RecordTypeAAAA, "2001::222:111:222:111"), + } + + deleteRecords := []*endpoint.Endpoint{ + endpoint.NewEndpoint("deleted.foo.example.com", endpoint.RecordTypeA, "111.222.111.222"), + endpoint.NewEndpoint("deletedaaaa.foo.example.com", endpoint.RecordTypeAAAA, "2001::111:222:111:222"), + endpoint.NewEndpoint("deletedcname.foo.example.com", endpoint.RecordTypeCNAME, "other.com"), + endpoint.NewEndpoint("deleted.nope.example.com", endpoint.RecordTypeA, "222.111.222.111"), + } + + changes := &plan.Changes{ + Create: createRecords, + UpdateNew: updatedRecords, + UpdateOld: currentRecords, + Delete: deleteRecords, + } + + if err := provider.ApplyChanges(context.Background(), changes); err != nil { + t.Fatal(err) + } +} diff --git a/internal/external-dns/azure/common.go b/internal/external-dns/azure/common.go new file mode 100644 index 00000000..688a0a57 --- /dev/null +++ b/internal/external-dns/azure/common.go @@ -0,0 +1,47 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//nolint:staticcheck // Required due to the current dependency on a deprecated version of azure-sdk-for-go +package azure + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns" + privatedns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" +) + +// Helper function (shared with test code) +func parseMxTarget[T dns.MxRecord | privatedns.MxRecord](mxTarget string) (T, error) { + targetParts := strings.SplitN(mxTarget, " ", 2) + if len(targetParts) != 2 { + return T{}, fmt.Errorf("mx target needs to be of form '10 example.com'") + } + + preferenceRaw, exchange := targetParts[0], targetParts[1] + preference, err := strconv.ParseInt(preferenceRaw, 10, 32) + if err != nil { + return T{}, fmt.Errorf("invalid preference specified") + } + + return T{ + Preference: to.Ptr(int32(preference)), + Exchange: to.Ptr(exchange), + }, nil +} diff --git a/internal/external-dns/azure/common_test.go b/internal/external-dns/azure/common_test.go new file mode 100644 index 00000000..b85fb5f4 --- /dev/null +++ b/internal/external-dns/azure/common_test.go @@ -0,0 +1,87 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "fmt" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns" + privatedns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" + + "github.com/stretchr/testify/assert" +) + +func Test_parseMxTarget(t *testing.T) { + type testCase[T interface { + dns.MxRecord | privatedns.MxRecord + }] struct { + name string + args string + want T + wantErr assert.ErrorAssertionFunc + } + + tests := []testCase[dns.MxRecord]{ + { + name: "valid mx target", + args: "10 example.com", + want: dns.MxRecord{ + Preference: to.Ptr(int32(10)), + Exchange: to.Ptr("example.com"), + }, + wantErr: assert.NoError, + }, + { + name: "valid mx target with a subdomain", + args: "99 foo-bar.example.com", + want: dns.MxRecord{ + Preference: to.Ptr(int32(99)), + Exchange: to.Ptr("foo-bar.example.com"), + }, + wantErr: assert.NoError, + }, + { + name: "invalid mx target with misplaced preference and exchange", + args: "example.com 10", + want: dns.MxRecord{}, + wantErr: assert.Error, + }, + { + name: "invalid mx target without preference", + args: "example.com", + want: dns.MxRecord{}, + wantErr: assert.Error, + }, + { + name: "invalid mx target with non numeric preference", + args: "aa example.com", + want: dns.MxRecord{}, + wantErr: assert.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseMxTarget[dns.MxRecord](tt.args) + if !tt.wantErr(t, err, fmt.Sprintf("parseMxTarget(%v)", tt.args)) { + return + } + assert.Equalf(t, tt.want, got, "parseMxTarget(%v)", tt.args) + }) + } +} diff --git a/internal/external-dns/azure/config.go b/internal/external-dns/azure/config.go new file mode 100644 index 00000000..c148baf8 --- /dev/null +++ b/internal/external-dns/azure/config.go @@ -0,0 +1,154 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "fmt" + "os" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" +) + +// config represents common config items for Azure DNS and Azure Private DNS +type config struct { + Cloud string `json:"cloud" yaml:"cloud"` + TenantID string `json:"tenantId" yaml:"tenantId"` + SubscriptionID string `json:"subscriptionId" yaml:"subscriptionId"` + ResourceGroup string `json:"resourceGroup" yaml:"resourceGroup"` + Location string `json:"location" yaml:"location"` + ClientID string `json:"aadClientId" yaml:"aadClientId"` + ClientSecret string `json:"aadClientSecret" yaml:"aadClientSecret"` + UseManagedIdentityExtension bool `json:"useManagedIdentityExtension" yaml:"useManagedIdentityExtension"` + UseWorkloadIdentityExtension bool `json:"useWorkloadIdentityExtension" yaml:"useWorkloadIdentityExtension"` + UserAssignedIdentityID string `json:"userAssignedIdentityID" yaml:"userAssignedIdentityID"` +} + +func getConfig(configFile, resourceGroup, userAssignedIdentityClientID string) (*config, error) { + contents, err := os.ReadFile(configFile) + if err != nil { + return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) + } + cfg := &config{} + err = yaml.Unmarshal(contents, &cfg) + if err != nil { + return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err) + } + + // If a resource group was given, override what was present in the config file + if resourceGroup != "" { + cfg.ResourceGroup = resourceGroup + } + // If userAssignedIdentityClientID is provided explicitly, override existing one in config file + if userAssignedIdentityClientID != "" { + cfg.UserAssignedIdentityID = userAssignedIdentityClientID + } + return cfg, nil +} + +// getAccessToken retrieves Azure API access token. +func getCredentials(cfg config) (azcore.TokenCredential, *arm.ClientOptions, error) { + cloudCfg, err := getCloudConfiguration(cfg.Cloud) + if err != nil { + return nil, nil, fmt.Errorf("failed to get cloud configuration: %w", err) + } + clientOpts := azcore.ClientOptions{ + Cloud: cloudCfg, + } + armClientOpts := &arm.ClientOptions{ + ClientOptions: clientOpts, + } + + // Try to retrieve token with service principal credentials. + // Try to use service principal first, some AKS clusters are in an intermediate state that `UseManagedIdentityExtension` is `true` + // and service principal exists. In this case, we still want to use service principal to authenticate. + if len(cfg.ClientID) > 0 && + len(cfg.ClientSecret) > 0 && + // due to some historical reason, for pure MSI cluster, + // they will use "msi" as placeholder in azure.json. + // In this case, we shouldn't try to use SPN to authenticate. + !strings.EqualFold(cfg.ClientID, "msi") && + !strings.EqualFold(cfg.ClientSecret, "msi") { + log.Info("Using client_id+client_secret to retrieve access token for Azure API.") + opts := &azidentity.ClientSecretCredentialOptions{ + ClientOptions: clientOpts, + } + cred, err := azidentity.NewClientSecretCredential(cfg.TenantID, cfg.ClientID, cfg.ClientSecret, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to create service principal token: %w", err) + } + return cred, armClientOpts, nil + } + + // Try to retrieve token with Workload Identity. + if cfg.UseWorkloadIdentityExtension { + log.Info("Using workload identity extension to retrieve access token for Azure API.") + + wiOpt := azidentity.WorkloadIdentityCredentialOptions{ + ClientOptions: clientOpts, + // In a standard scenario, Client ID and Tenant ID are expected to be read from environment variables. + // Though, in certain cases, it might be important to have an option to override those (e.g. when AZURE_TENANT_ID is not set + // through a webhook or azure.workload.identity/client-id service account annotation is absent). When any of those values are + // empty in our config, they will automatically be read from environment variables by azidentity + TenantID: cfg.TenantID, + ClientID: cfg.ClientID, + } + + cred, err := azidentity.NewWorkloadIdentityCredential(&wiOpt) + if err != nil { + return nil, nil, fmt.Errorf("failed to create a workload identity token: %w", err) + } + + return cred, armClientOpts, nil + } + + // Try to retrieve token with MSI. + if cfg.UseManagedIdentityExtension { + log.Info("Using managed identity extension to retrieve access token for Azure API.") + msiOpt := azidentity.ManagedIdentityCredentialOptions{ + ClientOptions: clientOpts, + } + if cfg.UserAssignedIdentityID != "" { + msiOpt.ID = azidentity.ClientID(cfg.UserAssignedIdentityID) + } + cred, err := azidentity.NewManagedIdentityCredential(&msiOpt) + if err != nil { + return nil, nil, fmt.Errorf("failed to create the managed service identity token: %w", err) + } + return cred, armClientOpts, nil + } + + return nil, nil, fmt.Errorf("no credentials provided for Azure API") +} + +func getCloudConfiguration(name string) (cloud.Configuration, error) { + name = strings.ToUpper(name) + switch name { + case "AZURECLOUD", "AZUREPUBLICCLOUD", "": + return cloud.AzurePublic, nil + case "AZUREUSGOVERNMENT", "AZUREUSGOVERNMENTCLOUD": + return cloud.AzureGovernment, nil + case "AZURECHINACLOUD": + return cloud.AzureChina, nil + } + return cloud.Configuration{}, fmt.Errorf("unknown cloud name: %s", name) +} diff --git a/internal/external-dns/azure/config_test.go b/internal/external-dns/azure/config_test.go new file mode 100644 index 00000000..7551fa51 --- /dev/null +++ b/internal/external-dns/azure/config_test.go @@ -0,0 +1,46 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" +) + +func TestGetCloudConfiguration(t *testing.T) { + tests := map[string]struct { + cloudName string + expected cloud.Configuration + }{ + "AzureChinaCloud": {"AzureChinaCloud", cloud.AzureChina}, + "AzurePublicCloud": {"", cloud.AzurePublic}, + "AzureUSGovernment": {"AzureUSGovernmentCloud", cloud.AzureGovernment}, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + cloudCfg, err := getCloudConfiguration(test.cloudName) + if err != nil { + t.Errorf("got unexpected err %v", err) + } + if cloudCfg.ActiveDirectoryAuthorityHost != test.expected.ActiveDirectoryAuthorityHost { + t.Errorf("got %v, want %v", cloudCfg, test.expected) + } + }) + } +}