diff --git a/http.go b/http.go index adf9bb1..b9064a1 100644 --- a/http.go +++ b/http.go @@ -2,6 +2,7 @@ package doh import ( "bytes" + "context" "fmt" "io/ioutil" "net/http" @@ -11,11 +12,11 @@ import ( // request as described in RFC 8484, and returns the response's body. // Returns an error if there was an issue sending the request or reading the // response body. -func (r *Resolver) exchangeHTTPS(q []byte) (a []byte, err error) { +func (r *Resolver) exchangeHTTPS(ctx context.Context, q []byte) (a []byte, err error) { url := fmt.Sprintf("https://%s/dns-query", r.Host) body := bytes.NewBuffer(q) - req, err := http.NewRequest("POST", url, body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body) if err != nil { return } diff --git a/resolver.go b/resolver.go index 0a5d85f..f3535e6 100644 --- a/resolver.go +++ b/resolver.go @@ -1,7 +1,10 @@ // Package doh implements client operations for DoH (DNS over HTTPS) lookups. package doh -import "net/http" +import ( + "context" + "net/http" +) // Resolver handles lookups. type Resolver struct { @@ -17,9 +20,9 @@ type Resolver struct { // lookup encodes a DNS query, sends it over HTTPS then parses the response. // Returns an error if something went wrong at the network level, or when // parsing the response headers. -func (r *Resolver) lookup(fqdn string, t DNSType, c DNSClass) ([]answer, error) { +func (r *Resolver) lookup(ctx context.Context, fqdn string, t DNSType, c DNSClass) ([]answer, error) { q := encodeQuery(fqdn, t, c) - res, err := r.exchangeHTTPS(q) + res, err := r.exchangeHTTPS(ctx, q) if err != nil { return nil, err } @@ -31,12 +34,16 @@ func (r *Resolver) lookup(fqdn string, t DNSType, c DNSClass) ([]answer, error) // Returns an error if something went wrong at the network level, or when // parsing the response headers, or if the resolver's class isn't IN. func (r *Resolver) LookupA(fqdn string) (recs []*ARecord, ttls []uint32, err error) { + return r.LookupAContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupAContext(ctx context.Context, fqdn string) (recs []*ARecord, ttls []uint32, err error) { if r.Class != IN && r.Class != ANYCLASS { err = ErrNotIN return } - answers, err := r.lookup(fqdn, A, IN) + answers, err := r.lookup(ctx, fqdn, A, IN) if err != nil { return } @@ -59,12 +66,16 @@ func (r *Resolver) LookupA(fqdn string) (recs []*ARecord, ttls []uint32, err err // Returns an error if something went wrong at the network level, or when // parsing the response headers, or if the resolver's class isn't IN. func (r *Resolver) LookupAAAA(fqdn string) (recs []*AAAARecord, ttls []uint32, err error) { + return r.LookupAAAAContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupAAAAContext(ctx context.Context, fqdn string) (recs []*AAAARecord, ttls []uint32, err error) { if r.Class != IN && r.Class != ANYCLASS { err = ErrNotIN return } - answers, err := r.lookup(fqdn, AAAA, IN) + answers, err := r.lookup(ctx, fqdn, AAAA, IN) if err != nil { return } @@ -87,7 +98,11 @@ func (r *Resolver) LookupAAAA(fqdn string) (recs []*AAAARecord, ttls []uint32, e // Returns an error if something went wrong at the network level, or when // parsing the response headers. func (r *Resolver) LookupCNAME(fqdn string) (recs []*CNAMERecord, ttls []uint32, err error) { - answers, err := r.lookup(fqdn, CNAME, IN) + return r.LookupCNAMEContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupCNAMEContext(ctx context.Context, fqdn string) (recs []*CNAMERecord, ttls []uint32, err error) { + answers, err := r.lookup(ctx, fqdn, CNAME, IN) if err != nil { return } @@ -110,7 +125,11 @@ func (r *Resolver) LookupCNAME(fqdn string) (recs []*CNAMERecord, ttls []uint32, // Returns an error if something went wrong at the network level, or when // parsing the response headers. func (r *Resolver) LookupMX(fqdn string) (recs []*MXRecord, ttls []uint32, err error) { - answers, err := r.lookup(fqdn, MX, IN) + return r.LookupMXContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupMXContext(ctx context.Context, fqdn string) (recs []*MXRecord, ttls []uint32, err error) { + answers, err := r.lookup(ctx, fqdn, MX, IN) if err != nil { return } @@ -133,7 +152,11 @@ func (r *Resolver) LookupMX(fqdn string) (recs []*MXRecord, ttls []uint32, err e // Returns an error if something went wrong at the network level, or when // parsing the response headers. func (r *Resolver) LookupNS(fqdn string) (recs []*NSRecord, ttls []uint32, err error) { - answers, err := r.lookup(fqdn, NS, IN) + return r.LookupNSContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupNSContext(ctx context.Context, fqdn string) (recs []*NSRecord, ttls []uint32, err error) { + answers, err := r.lookup(ctx, fqdn, NS, IN) if err != nil { return } @@ -156,7 +179,11 @@ func (r *Resolver) LookupNS(fqdn string) (recs []*NSRecord, ttls []uint32, err e // Returns an error if something went wrong at the network level, or when // parsing the response headers. func (r *Resolver) LookupTXT(fqdn string) (recs []*TXTRecord, ttls []uint32, err error) { - answers, err := r.lookup(fqdn, TXT, IN) + return r.LookupTXTContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupTXTContext(ctx context.Context, fqdn string) (recs []*TXTRecord, ttls []uint32, err error) { + answers, err := r.lookup(ctx, fqdn, TXT, IN) if err != nil { return } @@ -179,7 +206,11 @@ func (r *Resolver) LookupTXT(fqdn string) (recs []*TXTRecord, ttls []uint32, err // Returns an error if something went wrong at the network level, or when // parsing the response headers. func (r *Resolver) LookupSRV(fqdn string) (recs []*SRVRecord, ttls []uint32, err error) { - answers, err := r.lookup(fqdn, SRV, IN) + return r.LookupSRVContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupSRVContext(ctx context.Context, fqdn string) (recs []*SRVRecord, ttls []uint32, err error) { + answers, err := r.lookup(ctx, fqdn, SRV, IN) if err != nil { return } @@ -205,7 +236,11 @@ func (r *Resolver) LookupSRV(fqdn string) (recs []*SRVRecord, ttls []uint32, err // Returns an error if something went wrong at the network level, or when // parsing the response headers. func (r *Resolver) LookupService(service, network, domain string) (recs []*SRVRecord, ttls []uint32, err error) { - return r.LookupSRV("_" + service + "._" + network + "." + domain) + return r.LookupServiceContext(context.Background(), service, network, domain) +} + +func (r *Resolver) LookupServiceContext(ctx context.Context, service, network, domain string) (recs []*SRVRecord, ttls []uint32, err error) { + return r.LookupSRVContext(ctx, "_"+service+"._"+network+"."+domain) } // LookupSOA performs a DoH lookup on SOA records for the given FQDN. @@ -213,7 +248,11 @@ func (r *Resolver) LookupService(service, network, domain string) (recs []*SRVRe // Returns an error if something went wrong at the network level, or when // parsing the response headers. func (r *Resolver) LookupSOA(fqdn string) (recs []*SOARecord, ttls []uint32, err error) { - answers, err := r.lookup(fqdn, SOA, IN) + return r.LookupSOAContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupSOAContext(ctx context.Context, fqdn string) (recs []*SOARecord, ttls []uint32, err error) { + answers, err := r.lookup(ctx, fqdn, SOA, IN) if err != nil { return } @@ -236,7 +275,11 @@ func (r *Resolver) LookupSOA(fqdn string) (recs []*SOARecord, ttls []uint32, err // Returns an error if something went wrong at the network level, or when // parsing the response headers. func (r *Resolver) LookupPTR(fqdn string) (recs []*PTRRecord, ttls []uint32, err error) { - answers, err := r.lookup(fqdn, PTR, IN) + return r.LookupPTRContext(context.Background(), fqdn) +} + +func (r *Resolver) LookupPTRContext(ctx context.Context, fqdn string) (recs []*PTRRecord, ttls []uint32, err error) { + answers, err := r.lookup(ctx, fqdn, PTR, IN) if err != nil { return }