Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resolver functions taking context argument #6

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package doh

import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
Expand All @@ -11,11 +12,11 @@ import (
// request as described in RFC 8484, and returns the response's body.
// Returns an error if there was an issue sending the request or reading the
// response body.
func (r *Resolver) exchangeHTTPS(q []byte) (a []byte, err error) {
func (r *Resolver) exchangeHTTPS(ctx context.Context, q []byte) (a []byte, err error) {
url := fmt.Sprintf("https://%s/dns-query", r.Host)
body := bytes.NewBuffer(q)

req, err := http.NewRequest("POST", url, body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return
}
Expand Down
69 changes: 56 additions & 13 deletions resolver.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// Package doh implements client operations for DoH (DNS over HTTPS) lookups.
package doh

import "net/http"
import (
"context"
"net/http"
)

// Resolver handles lookups.
type Resolver struct {
Expand All @@ -17,9 +20,9 @@ type Resolver struct {
// lookup encodes a DNS query, sends it over HTTPS then parses the response.
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) lookup(fqdn string, t DNSType, c DNSClass) ([]answer, error) {
func (r *Resolver) lookup(ctx context.Context, fqdn string, t DNSType, c DNSClass) ([]answer, error) {
q := encodeQuery(fqdn, t, c)
res, err := r.exchangeHTTPS(q)
res, err := r.exchangeHTTPS(ctx, q)
if err != nil {
return nil, err
}
Expand All @@ -31,12 +34,16 @@ func (r *Resolver) lookup(fqdn string, t DNSType, c DNSClass) ([]answer, error)
// Returns an error if something went wrong at the network level, or when
// parsing the response headers, or if the resolver's class isn't IN.
func (r *Resolver) LookupA(fqdn string) (recs []*ARecord, ttls []uint32, err error) {
return r.LookupAContext(context.Background(), fqdn)
}

func (r *Resolver) LookupAContext(ctx context.Context, fqdn string) (recs []*ARecord, ttls []uint32, err error) {
if r.Class != IN && r.Class != ANYCLASS {
err = ErrNotIN
return
}

answers, err := r.lookup(fqdn, A, IN)
answers, err := r.lookup(ctx, fqdn, A, IN)
if err != nil {
return
}
Expand All @@ -59,12 +66,16 @@ func (r *Resolver) LookupA(fqdn string) (recs []*ARecord, ttls []uint32, err err
// Returns an error if something went wrong at the network level, or when
// parsing the response headers, or if the resolver's class isn't IN.
func (r *Resolver) LookupAAAA(fqdn string) (recs []*AAAARecord, ttls []uint32, err error) {
return r.LookupAAAAContext(context.Background(), fqdn)
}

func (r *Resolver) LookupAAAAContext(ctx context.Context, fqdn string) (recs []*AAAARecord, ttls []uint32, err error) {
if r.Class != IN && r.Class != ANYCLASS {
err = ErrNotIN
return
}

answers, err := r.lookup(fqdn, AAAA, IN)
answers, err := r.lookup(ctx, fqdn, AAAA, IN)
if err != nil {
return
}
Expand All @@ -87,7 +98,11 @@ func (r *Resolver) LookupAAAA(fqdn string) (recs []*AAAARecord, ttls []uint32, e
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) LookupCNAME(fqdn string) (recs []*CNAMERecord, ttls []uint32, err error) {
answers, err := r.lookup(fqdn, CNAME, IN)
return r.LookupCNAMEContext(context.Background(), fqdn)
}

func (r *Resolver) LookupCNAMEContext(ctx context.Context, fqdn string) (recs []*CNAMERecord, ttls []uint32, err error) {
answers, err := r.lookup(ctx, fqdn, CNAME, IN)
if err != nil {
return
}
Expand All @@ -110,7 +125,11 @@ func (r *Resolver) LookupCNAME(fqdn string) (recs []*CNAMERecord, ttls []uint32,
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) LookupMX(fqdn string) (recs []*MXRecord, ttls []uint32, err error) {
answers, err := r.lookup(fqdn, MX, IN)
return r.LookupMXContext(context.Background(), fqdn)
}

func (r *Resolver) LookupMXContext(ctx context.Context, fqdn string) (recs []*MXRecord, ttls []uint32, err error) {
answers, err := r.lookup(ctx, fqdn, MX, IN)
if err != nil {
return
}
Expand All @@ -133,7 +152,11 @@ func (r *Resolver) LookupMX(fqdn string) (recs []*MXRecord, ttls []uint32, err e
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) LookupNS(fqdn string) (recs []*NSRecord, ttls []uint32, err error) {
answers, err := r.lookup(fqdn, NS, IN)
return r.LookupNSContext(context.Background(), fqdn)
}

func (r *Resolver) LookupNSContext(ctx context.Context, fqdn string) (recs []*NSRecord, ttls []uint32, err error) {
answers, err := r.lookup(ctx, fqdn, NS, IN)
if err != nil {
return
}
Expand All @@ -156,7 +179,11 @@ func (r *Resolver) LookupNS(fqdn string) (recs []*NSRecord, ttls []uint32, err e
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) LookupTXT(fqdn string) (recs []*TXTRecord, ttls []uint32, err error) {
answers, err := r.lookup(fqdn, TXT, IN)
return r.LookupTXTContext(context.Background(), fqdn)
}

func (r *Resolver) LookupTXTContext(ctx context.Context, fqdn string) (recs []*TXTRecord, ttls []uint32, err error) {
answers, err := r.lookup(ctx, fqdn, TXT, IN)
if err != nil {
return
}
Expand All @@ -179,7 +206,11 @@ func (r *Resolver) LookupTXT(fqdn string) (recs []*TXTRecord, ttls []uint32, err
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) LookupSRV(fqdn string) (recs []*SRVRecord, ttls []uint32, err error) {
answers, err := r.lookup(fqdn, SRV, IN)
return r.LookupSRVContext(context.Background(), fqdn)
}

func (r *Resolver) LookupSRVContext(ctx context.Context, fqdn string) (recs []*SRVRecord, ttls []uint32, err error) {
answers, err := r.lookup(ctx, fqdn, SRV, IN)
if err != nil {
return
}
Expand All @@ -205,15 +236,23 @@ func (r *Resolver) LookupSRV(fqdn string) (recs []*SRVRecord, ttls []uint32, err
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) LookupService(service, network, domain string) (recs []*SRVRecord, ttls []uint32, err error) {
return r.LookupSRV("_" + service + "._" + network + "." + domain)
return r.LookupServiceContext(context.Background(), service, network, domain)
}

func (r *Resolver) LookupServiceContext(ctx context.Context, service, network, domain string) (recs []*SRVRecord, ttls []uint32, err error) {
return r.LookupSRVContext(ctx, "_"+service+"._"+network+"."+domain)
}

// LookupSOA performs a DoH lookup on SOA records for the given FQDN.
// Returns records and TTLs such that ttls[0] is the TTL for recs[0], and so on.
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) LookupSOA(fqdn string) (recs []*SOARecord, ttls []uint32, err error) {
answers, err := r.lookup(fqdn, SOA, IN)
return r.LookupSOAContext(context.Background(), fqdn)
}

func (r *Resolver) LookupSOAContext(ctx context.Context, fqdn string) (recs []*SOARecord, ttls []uint32, err error) {
answers, err := r.lookup(ctx, fqdn, SOA, IN)
if err != nil {
return
}
Expand All @@ -236,7 +275,11 @@ func (r *Resolver) LookupSOA(fqdn string) (recs []*SOARecord, ttls []uint32, err
// Returns an error if something went wrong at the network level, or when
// parsing the response headers.
func (r *Resolver) LookupPTR(fqdn string) (recs []*PTRRecord, ttls []uint32, err error) {
answers, err := r.lookup(fqdn, PTR, IN)
return r.LookupPTRContext(context.Background(), fqdn)
}

func (r *Resolver) LookupPTRContext(ctx context.Context, fqdn string) (recs []*PTRRecord, ttls []uint32, err error) {
answers, err := r.lookup(ctx, fqdn, PTR, IN)
if err != nil {
return
}
Expand Down