Skip to content

Commit

Permalink
SDK-less implementation using token
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jun 16, 2024
1 parent 2dc61b8 commit 8e39ff0
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 92 deletions.
22 changes: 2 additions & 20 deletions docs/azure.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@

- `"domain"`
- `"host"`
- `"tenant_id"`
- `"client_id"`
- `"client_secret"`
- `"token"`
- `"subscription_id"` found in the properties section of Azure DNS
- `"resource_group_name"` found in the properties section of Azure DNS

Expand All @@ -45,20 +43,4 @@ Thanks to @danimart1991 for describing the following steps!
- The name or URL `AnyNameOrUrl` for the query below **TODO**
- `subscription_id`
- `resource_group_name`
1. In the Azure Console (inside the portal), run:

```sh
az ad sp create-for-rbac -n "$AnyNameOrUrl" --scopes "/subscriptions/$subscription_id/resourceGroups/$resource_group_name/providers/Microsoft.Network/dnszones/$zone_name"
```

This gives you the rest of the parameters:

```json
{
"appId": "{app_id/client_id}",
"displayName": "not important",
"name": "not important",
"password": "{app_password}",
"tenant": "not important"
}
```
1. Get your token, see [this article](https://mauridb.medium.com/calling-azure-rest-api-via-curl-eb10a06127)
223 changes: 183 additions & 40 deletions internal/provider/providers/azure/api.go
Original file line number Diff line number Diff line change
@@ -1,77 +1,220 @@
package azure

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/netip"
"net/url"

"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
"github.com/qdm12/ddns-updater/internal/provider/constants"
"github.com/qdm12/ddns-updater/internal/provider/errors"
"github.com/qdm12/ddns-updater/internal/provider/headers"
"github.com/qdm12/ddns-updater/internal/provider/utils"
)

func (p *Provider) createClient() (client *armdns.RecordSetsClient, err error) {
credential, err := azidentity.NewClientSecretCredential(p.tenantID, p.clientID, p.clientSecret, nil)
type rrSet struct {
ID string `json:"id"`
Etag string `json:"etag"`
Name string `json:"name"`
Type string `json:"type"`
Properties struct {
Metadata map[string]string `json:"metadata"`
TTL uint32 `json:"TTL"`
FQDN string `json:"fqdn"`
ARecords []arecord `json:"ARecords"`
AAAARecords []aaaarecord `json:"AAAARecords"`
} `json:"properties"`
}

type arecord struct {
IPv4Address string `json:"ipv4Address"`
}

type aaaarecord struct {
IPv6Address string `json:"ipv6Address"`
}

func makeURL(subscriptionID, resourceGroupName, domain, recordType, host string) string {
path := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/dnsZones/%s/%s/%s",
subscriptionID, resourceGroupName, domain, recordType, host)
values := url.Values{}
values.Set("api-version", "2018-05-01")
u := url.URL{
Scheme: "https",
Host: "management.azure.com",
Path: path,
RawQuery: values.Encode(),
}
return u.String()
}

func (p *Provider) getRecordSet(ctx context.Context, client *http.Client,
recordType string) (data rrSet, err error) {
url := makeURL(p.subscriptionID, p.resourceGroupName, p.domain, recordType, p.host)
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("creating client secret credential: %w", err)
return data, err
}
headers.SetUserAgent(request)
headers.SetAuthBearer(request, p.token)
headers.SetAccept(request, "application/json")

client, err = armdns.NewRecordSetsClient(p.subscriptionID, credential, nil)
response, err := client.Do(request)
if err != nil {
return nil, fmt.Errorf("creating record sets client: %w", err)
return data, err
}

return client, nil
}
switch response.StatusCode {
case http.StatusOK:
case http.StatusNotFound:
return data, fmt.Errorf("%w: %s %s",
errors.ErrRecordNotFound, p.host, recordType)
default:
message := decodeError(response.Body)
_ = response.Body.Close()
return data, fmt.Errorf("%w: %d: %s", errors.ErrHTTPStatusNotValid,
response.StatusCode, message)
}

func (p *Provider) getRecordSet(ctx context.Context, client *armdns.RecordSetsClient,
recordType armdns.RecordType) (response armdns.RecordSetsClientGetResponse, err error) {
return client.Get(ctx, p.resourceGroupName, p.domain, p.host, recordType, nil)
decoder := json.NewDecoder(response.Body)
err = decoder.Decode(&data)
_ = response.Body.Close()
if err != nil {
return data, fmt.Errorf("JSON decoding response: %w", err)
}

return data, nil
}

func (p *Provider) createRecordSet(ctx context.Context, client *armdns.RecordSetsClient,
func (p *Provider) createRecordSet(ctx context.Context, client *http.Client,
ip netip.Addr) (err error) {
rrSet := armdns.RecordSet{Properties: &armdns.RecordSetProperties{}}
recordType := armdns.RecordTypeA
var data rrSet
recordType := constants.A
if ip.Is4() {
rrSet.Properties.ARecords = []*armdns.ARecord{{IPv4Address: ptrTo(ip.String())}}
data.Properties.ARecords = []arecord{{IPv4Address: ip.String()}}
} else {
recordType = armdns.RecordTypeAAAA
rrSet.Properties.AaaaRecords = []*armdns.AaaaRecord{{IPv6Address: ptrTo(ip.String())}}
recordType = constants.AAAA
data.Properties.AAAARecords = []aaaarecord{{IPv6Address: ip.String()}}
}

buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
err = encoder.Encode(data)
if err != nil {
return fmt.Errorf("JSON encoding request body: %w", err)
}

url := makeURL(p.subscriptionID, p.resourceGroupName, p.domain, recordType, p.host)

request, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, buffer)
if err != nil {
return err
}
_, err = client.CreateOrUpdate(ctx, p.resourceGroupName, p.domain,
p.host, recordType, rrSet, nil)
headers.SetUserAgent(request)
headers.SetAuthBearer(request, p.token)
headers.SetContentType(request, "application/json")
headers.SetAccept(request, "application/json")

response, err := client.Do(request)
if err != nil {
return fmt.Errorf("creating record set: %w", err)
return err
}

if response.StatusCode != http.StatusOK {
message := decodeError(response.Body)
_ = response.Body.Close()
return fmt.Errorf("%w: %d: %s", errors.ErrHTTPStatusNotValid,
response.StatusCode, message)
}

err = response.Body.Close()
if err != nil {
return fmt.Errorf("closing response body: %w", err)
}

return nil
}

func (p *Provider) updateRecordSet(ctx context.Context, client *armdns.RecordSetsClient,
response armdns.RecordSetsClientGetResponse, ip netip.Addr) (err error) {
properties := response.Properties
recordType := armdns.RecordTypeA
func (p *Provider) updateRecordSet(ctx context.Context, client *http.Client,
data rrSet, ip netip.Addr) (err error) {
recordType := constants.A
if ip.Is4() {
if len(properties.ARecords) == 0 {
properties.ARecords = make([]*armdns.ARecord, 1)
if len(data.Properties.ARecords) == 0 {
data.Properties.ARecords = make([]arecord, 1)
}
for i := range properties.ARecords {
properties.ARecords[i].IPv4Address = ptrTo(ip.String())
for i := range data.Properties.ARecords {
data.Properties.ARecords[i].IPv4Address = ip.String()
}
data.Properties.ARecords = []arecord{{IPv4Address: ip.String()}}
} else {
recordType = armdns.RecordTypeAAAA
if len(properties.AaaaRecords) == 0 {
properties.AaaaRecords = make([]*armdns.AaaaRecord, 1)
recordType = constants.AAAA
if len(data.Properties.AAAARecords) == 0 {
data.Properties.AAAARecords = make([]aaaarecord, 1)
}
for i := range properties.AaaaRecords {
properties.AaaaRecords[i].IPv6Address = ptrTo(ip.String())
for i := range data.Properties.AAAARecords {
data.Properties.AAAARecords[i].IPv6Address = ip.String()
}
}
rrSet := armdns.RecordSet{
Etag: response.Etag,
Properties: properties,

buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
err = encoder.Encode(data)
if err != nil {
return fmt.Errorf("JSON encoding request body: %w", err)
}
url := makeURL(p.subscriptionID, p.resourceGroupName, p.domain, recordType, p.host)
request, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, buffer)
if err != nil {
return err
}
headers.SetUserAgent(request)
headers.SetAuthBearer(request, p.token)
headers.SetContentType(request, "application/json")
headers.SetAccept(request, "application/json")

response, err := client.Do(request)
if err != nil {
return err
}

if response.StatusCode != http.StatusOK {
message := decodeError(response.Body)
_ = response.Body.Close()
return fmt.Errorf("%w: %d: %s", errors.ErrHTTPStatusNotValid,
response.StatusCode, message)
}

err = response.Body.Close()
if err != nil {
return fmt.Errorf("closing response body: %w", err)
}

_, err = client.CreateOrUpdate(ctx, p.resourceGroupName, p.domain,
p.host, recordType, rrSet, nil)
return err
return nil
}

func decodeError(body io.ReadCloser) (message string) {
type cloudErrorBody struct {
Code string `json:"code"`
Message string `json:"message"`
Target string `json:"target"`
Details []cloudErrorBody `json:"details"`
}
var errorBody struct {
Error cloudErrorBody `json:"error"`
}
b, err := io.ReadAll(body)
if err != nil {
return err.Error()
}
err = json.Unmarshal(b, &errorBody)
_ = body.Close()
if err != nil {
return utils.ToSingleLine(string(b))
}
return fmt.Sprintf("%s: %s (target: %s)",
errorBody.Error.Code, errorBody.Error.Message, errorBody.Error.Target)
}
43 changes: 11 additions & 32 deletions internal/provider/providers/azure/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"net/http"
"net/netip"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
"github.com/qdm12/ddns-updater/internal/models"
"github.com/qdm12/ddns-updater/internal/provider/constants"
ddnserrors "github.com/qdm12/ddns-updater/internal/provider/errors"
Expand All @@ -22,17 +20,13 @@ type Provider struct {
host string // aka relativeRecordSetName
ipVersion ipversion.IPVersion
ipv6Suffix netip.Prefix
tenantID string
clientID string
clientSecret string
token string
subscriptionID string
resourceGroupName string
}

type settings struct {
TenantID string `json:"tenant_id"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
Token string `json:"token"`
SubscriptionID string `json:"subscription_id"`
ResourceGroupName string `json:"resource_group_name"`
}
Expand All @@ -54,9 +48,7 @@ func New(data json.RawMessage, domain, host string,
host: host,
ipVersion: ipVersion,
ipv6Suffix: ipv6Suffix,
tenantID: providerSpecificSettings.TenantID,
clientID: providerSpecificSettings.ClientID,
clientSecret: providerSpecificSettings.ClientSecret,
token: providerSpecificSettings.Token,
subscriptionID: providerSpecificSettings.SubscriptionID,
resourceGroupName: providerSpecificSettings.ResourceGroupName,
}, nil
Expand All @@ -68,12 +60,8 @@ func validateSettings(domain, host string, settings settings) error {
return fmt.Errorf("%w", ddnserrors.ErrDomainNotSet)
case host == "":
return fmt.Errorf("%w", ddnserrors.ErrHostNotSet)
case settings.TenantID == "":
return fmt.Errorf("%w: tenant id", ddnserrors.ErrCredentialsNotSet)
case settings.ClientID == "":
return fmt.Errorf("%w: client id", ddnserrors.ErrCredentialsNotSet)
case settings.ClientSecret == "":
return fmt.Errorf("%w: client secret", ddnserrors.ErrCredentialsNotSet)
case settings.Token == "":
return fmt.Errorf("%w", ddnserrors.ErrTokenNotSet)
case settings.SubscriptionID == "":
return fmt.Errorf("%w: subscription id", ddnserrors.ErrKeyNotSet)
case settings.ResourceGroupName == "":
Expand Down Expand Up @@ -119,29 +107,20 @@ func (p *Provider) HTML() models.HTMLRow {
}
}

func ptrTo[T any](v T) *T { return &v }

func (p *Provider) Update(ctx context.Context, _ *http.Client, ip netip.Addr) (newIP netip.Addr, err error) {
var recordType armdns.RecordType
if ip.Is4() {
recordType = armdns.RecordTypeA
} else {
recordType = armdns.RecordTypeAAAA
}

client, err := p.createClient()
if err != nil {
return netip.Addr{}, fmt.Errorf("creating client: %w", err)
func (p *Provider) Update(ctx context.Context, client *http.Client, ip netip.Addr) (newIP netip.Addr, err error) {
recordType := constants.A
if ip.Is6() {
recordType = constants.AAAA
}

response, err := p.getRecordSet(ctx, client, recordType)
if err != nil {
azureErr := &azcore.ResponseError{}
if errors.As(err, &azureErr) && azureErr.StatusCode == http.StatusNotFound {
if errors.Is(err, ddnserrors.ErrRecordNotFound) {
err = p.createRecordSet(ctx, client, ip)
if err != nil {
return netip.Addr{}, fmt.Errorf("creating record set: %w", err)
}
return ip, nil
}
return netip.Addr{}, fmt.Errorf("getting record set: %w", err)
}
Expand Down

0 comments on commit 8e39ff0

Please sign in to comment.