From 2e9745f2221df3e02c38b2811001856107eb9289 Mon Sep 17 00:00:00 2001 From: phuslu Date: Wed, 23 Oct 2024 15:40:19 +0800 Subject: [PATCH] switch to range func --- .github/workflows/benchmark.yml | 4 +- .github/workflows/build.yml | 6 +-- client_resolver.go | 69 ++++++++++++++------------------- client_test.go | 13 +++---- cmd/fastdig/fastdig.go | 57 +++++++++++++-------------- cmd/fastdoh/go.mod | 2 +- cmd/fastdoh/main.go | 13 +++---- go.mod | 2 +- message.go | 20 ++++++---- 9 files changed, 90 insertions(+), 96 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 4e49b8f..a5dde2f 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -11,9 +11,9 @@ jobs: benchmark: runs-on: ubuntu-latest steps: - - uses: actions/setup-go@v2 + - uses: actions/setup-go@v4 with: - go-version: '1.22.1' + go-version: '1.23.2' - name: Benchmark run: | set -ex diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 963894f..d92c355 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,9 +9,9 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/setup-go@v2 + - uses: actions/setup-go@v4 with: - go-version: '1.22.1' + go-version: '1.23.2' - name: Build run: | set -ex @@ -21,5 +21,5 @@ jobs: go test -v -cover go build -v -race (cd cmd/fastdig && go build -v) - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.56.2 + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.61.0 ./bin/golangci-lint run diff --git a/client_resolver.go b/client_resolver.go index b6420cf..ae041ac 100644 --- a/client_resolver.go +++ b/client_resolver.go @@ -40,18 +40,16 @@ func (c *Client) AppendLookupNetIP(dst []netip.Addr, ctx context.Context, networ } cname := make([]byte, 0, 64) - - _ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool { - switch typ { + for r := range resp.Records { + switch r.Type { case TypeCNAME: - cname = resp.DecodeName(cname[:0], data) + cname = resp.DecodeName(cname[:0], r.Data) case TypeA: - dst = append(dst, netip.AddrFrom4(*(*[4]byte)(data))) + dst = append(dst, netip.AddrFrom4(*(*[4]byte)(r.Data))) case TypeAAAA: - dst = append(dst, netip.AddrFrom16(*(*[16]byte)(data))) + dst = append(dst, netip.AddrFrom16(*(*[16]byte)(r.Data))) } - return true - }) + } if len(cname) != 0 && len(dst) == 0 { dst, err = c.AppendLookupNetIP(dst, ctx, network, b2s(cname)) @@ -78,16 +76,14 @@ func (c *Client) LookupCNAME(ctx context.Context, host string) (cname string, er return } - _ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool { - switch typ { + for r := range resp.Records { + switch r.Type { case TypeCNAME: - cname = string(resp.DecodeName(nil, data)) - return false + cname = string(resp.DecodeName(nil, r.Data)) default: err = ErrInvalidAnswer } - return true - }) + } return } @@ -107,17 +103,16 @@ func (c *Client) LookupNS(ctx context.Context, name string) (ns []*net.NS, err e soa := make([]byte, 0, 64) - _ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool { - switch typ { + for r := range resp.Records { + switch r.Type { case TypeSOA: - soa = resp.DecodeName(soa[:0], name) + soa = resp.DecodeName(soa[:0], r.Name) case TypeNS: - ns = append(ns, &net.NS{Host: string(resp.DecodeName(nil, data))}) + ns = append(ns, &net.NS{Host: string(resp.DecodeName(nil, r.Data))}) default: err = ErrInvalidAnswer } - return true - }) + } if len(soa) != 0 { ns, err = c.LookupNS(ctx, b2s(soa)) @@ -139,19 +134,18 @@ func (c *Client) LookupTXT(ctx context.Context, host string) (txt []string, err return } - _ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool { - switch typ { + for r := range resp.Records { + switch r.Type { case TypeTXT: - if len(data) > 1 && int(data[0])+1 == len(data) { - txt = append(txt, string(data[1:])) + if len(r.Data) > 1 && int(r.Data[0])+1 == len(r.Data) { + txt = append(txt, string(r.Data[1:])) } else { err = ErrInvalidAnswer } default: err = ErrInvalidAnswer } - return true - }) + } return } @@ -169,16 +163,14 @@ func (c *Client) LookupMX(ctx context.Context, host string) (mx []*net.MX, err e return } - _ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool { - switch typ { - case TypeMX: + for r := range resp.Records { + if r.Type == TypeMX { mx = append(mx, &net.MX{ - Host: string(resp.DecodeName(nil, data[2:])), - Pref: binary.BigEndian.Uint16(data), + Host: string(resp.DecodeName(nil, r.Data[2:])), + Pref: binary.BigEndian.Uint16(r.Data), }) } - return true - }) + } return } @@ -196,12 +188,12 @@ func (c *Client) LookupHTTPS(ctx context.Context, host string) (https []NetHTTPS return } - _ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool { - switch typ { - case TypeHTTPS: + for r := range resp.Records { + if r.Type == TypeHTTPS { var h NetHTTPS + data := r.Data if len(data) < 7 { - return true + return nil, ErrInvalidAnswer } data = data[3:] for len(data) >= 4 { @@ -243,8 +235,7 @@ func (c *Client) LookupHTTPS(ctx context.Context, host string) (https []NetHTTPS } https = append(https, h) } - return true - }) + } return } diff --git a/client_test.go b/client_test.go index 280d923..c5a31f2 100644 --- a/client_test.go +++ b/client_test.go @@ -51,17 +51,16 @@ func TestClientExchange(t *testing.T) { t.Errorf("client=%+v exchange(%v) error: %+v\n", client, c.Domain, err) } t.Logf("%s: CLASS %s TYPE %s\n", resp.Domain, resp.Question.Class, resp.Question.Type) - _ = resp.Walk(func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool { - switch typ { + for r := range resp.Records { + switch r.Type { case TypeCNAME: - t.Logf("%s.\t%d\t%s\t%s\t%s.\n", resp.DecodeName(nil, name), ttl, class, typ, resp.DecodeName(nil, data)) + t.Logf("%s.\t%d\t%s\t%s\t%s.\n", resp.DecodeName(nil, r.Name), r.TTL, r.Class, r.Type, resp.DecodeName(nil, r.Data)) case TypeA: - t.Logf("%s.\t%d\t%s\t%s\t%s\n", resp.DecodeName(nil, name), ttl, class, typ, netip.AddrFrom4(*(*[4]byte)(data))) + t.Logf("%s.\t%d\t%s\t%s\t%s\n", resp.DecodeName(nil, r.Name), r.TTL, r.Class, r.Type, netip.AddrFrom4(*(*[4]byte)(r.Data))) case TypeAAAA: - t.Logf("%s.\t%d\t%s\t%s\t%s\n", resp.DecodeName(nil, name), ttl, class, typ, netip.AddrFrom16(*(*[16]byte)(data))) + t.Logf("%s.\t%d\t%s\t%s\t%s\n", resp.DecodeName(nil, r.Name), r.TTL, r.Class, r.Type, netip.AddrFrom16(*(*[16]byte)(r.Data))) } - return true - }) + } } } diff --git a/cmd/fastdig/fastdig.go b/cmd/fastdig/fastdig.go index 368cc2c..b596301 100644 --- a/cmd/fastdig/fastdig.go +++ b/cmd/fastdig/fastdig.go @@ -103,48 +103,47 @@ func opt(option string, options []string) bool { } func short(resp *fastdns.Message) { - _ = resp.Walk(func(name []byte, typ fastdns.Type, class fastdns.Class, ttl uint32, data []byte) bool { + for r := range resp.Records { var v interface{} - switch typ { + switch r.Type { case fastdns.TypeA, fastdns.TypeAAAA: - v, _ = netip.AddrFromSlice(data) + v, _ = netip.AddrFromSlice(r.Data) case fastdns.TypeCNAME, fastdns.TypeNS: - v = fmt.Sprintf("%s.", resp.DecodeName(nil, data)) + v = fmt.Sprintf("%s.", resp.DecodeName(nil, r.Data)) case fastdns.TypeMX: - v = fmt.Sprintf("%d %s.", binary.BigEndian.Uint16(data), resp.DecodeName(nil, data[2:])) + v = fmt.Sprintf("%d %s.", binary.BigEndian.Uint16(r.Data), resp.DecodeName(nil, r.Data[2:])) case fastdns.TypeTXT: - v = fmt.Sprintf("\"%s\"", data[1:]) + v = fmt.Sprintf("\"%s\"", r.Data[1:]) case fastdns.TypeSRV: - priority := binary.BigEndian.Uint16(data) - weight := binary.BigEndian.Uint16(data[2:]) - port := binary.BigEndian.Uint16(data[4:]) - target := resp.DecodeName(nil, data[6:]) + priority := binary.BigEndian.Uint16(r.Data) + weight := binary.BigEndian.Uint16(r.Data[2:]) + port := binary.BigEndian.Uint16(r.Data[4:]) + target := resp.DecodeName(nil, r.Data[6:]) v = fmt.Sprintf("%d %d %d %s.", priority, weight, port, target) case fastdns.TypeSOA: var mname []byte - for i, b := range data { + for i, b := range r.Data { if b == 0 { - mname = data[:i+1] + mname = r.Data[:i+1] break } else if b&0b11000000 == 0b11000000 { - mname = data[:i+2] + mname = r.Data[:i+2] break } } - nname := resp.DecodeName(nil, data[len(mname):len(data)-20]) + nname := resp.DecodeName(nil, r.Data[len(mname):len(r.Data)-20]) mname = resp.DecodeName(nil, mname) - serial := binary.BigEndian.Uint32(data[len(data)-20:]) - refresh := binary.BigEndian.Uint32(data[len(data)-16:]) - retry := binary.BigEndian.Uint32(data[len(data)-12:]) - expire := binary.BigEndian.Uint32(data[len(data)-8:]) - minimum := binary.BigEndian.Uint32(data[len(data)-4:]) + serial := binary.BigEndian.Uint32(r.Data[len(r.Data)-20:]) + refresh := binary.BigEndian.Uint32(r.Data[len(r.Data)-16:]) + retry := binary.BigEndian.Uint32(r.Data[len(r.Data)-12:]) + expire := binary.BigEndian.Uint32(r.Data[len(r.Data)-8:]) + minimum := binary.BigEndian.Uint32(r.Data[len(r.Data)-4:]) v = fmt.Sprintf("%s. %s. %d %d %d %d %d", mname, nname, serial, refresh, retry, expire, minimum) default: - v = fmt.Sprintf("%x", data) + v = fmt.Sprintf("%x", r.Data) } fmt.Printf("%s\n", v) - return true - }) + } } func cmd(req, resp *fastdns.Message, server string, start, end time.Time) { @@ -187,11 +186,12 @@ func cmd(req, resp *fastdns.Message, server string, start, end time.Time) { } else { fmt.Printf(";; AUTHORITY SECTION:\n") } - var index int - _ = resp.Walk(func(name []byte, typ fastdns.Type, class fastdns.Class, ttl uint32, data []byte) bool { + index := 0 + for r := range resp.Records { index++ + data := r.Data var v interface{} - switch typ { + switch r.Type { case fastdns.TypeA, fastdns.TypeAAAA: v, _ = netip.AddrFromSlice(data) case fastdns.TypeCNAME, fastdns.TypeNS: @@ -228,7 +228,7 @@ func cmd(req, resp *fastdns.Message, server string, start, end time.Time) { case fastdns.TypeHTTPS: var h fastdns.NetHTTPS if len(data) < 7 { - return true + return } data = data[3:] for len(data) >= 4 { @@ -304,9 +304,8 @@ func cmd(req, resp *fastdns.Message, server string, start, end time.Time) { default: v = fmt.Sprintf("%x", data) } - fmt.Printf("%s. %d %s %s %s\n", resp.DecodeName(nil, name), ttl, class, typ, v) - return true - }) + fmt.Printf("%s. %d %s %s %s\n", resp.DecodeName(nil, r.Name), r.TTL, r.Class, r.Type, v) + } fmt.Printf("\n") fmt.Printf(";; Query time: %d msec\n", end.Sub(start)/time.Millisecond) diff --git a/cmd/fastdoh/go.mod b/cmd/fastdoh/go.mod index 04c8fec..14d11d2 100644 --- a/cmd/fastdoh/go.mod +++ b/cmd/fastdoh/go.mod @@ -1,6 +1,6 @@ module main -go 1.22 +go 1.23 require ( github.com/phuslu/fastdns v1.0.0 diff --git a/cmd/fastdoh/main.go b/cmd/fastdoh/main.go index 93a678f..9a1138d 100644 --- a/cmd/fastdoh/main.go +++ b/cmd/fastdoh/main.go @@ -37,17 +37,16 @@ func (h *DNSHandler) ServeDNS(rw fastdns.ResponseWriter, req *fastdns.Message) { } if h.Debug { - _ = resp.Walk(func(name []byte, typ fastdns.Type, class fastdns.Class, ttl uint32, data []byte) bool { - switch typ { + for r := range resp.Records { + switch r.Type { case fastdns.TypeCNAME: - slog.Info("dns request CNAME", "name", resp.DecodeName(nil, name), "ttl", ttl, "class", class, "type", typ, "CNAME", resp.DecodeName(nil, data)) + slog.Info("dns request CNAME", "name", resp.DecodeName(nil, r.Name), "ttl", r.TTL, "class", r.Class, "type", r.Type, "CNAME", resp.DecodeName(nil, r.Data)) case fastdns.TypeA: - slog.Info("dns request A", "name", resp.DecodeName(nil, name), "ttl", ttl, "class", class, "type", typ, "A", netip.AddrFrom4(*(*[4]byte)(data))) + slog.Info("dns request A", "name", resp.DecodeName(nil, r.Name), "ttl", r.TTL, "class", r.Class, "type", r.Type, "A", netip.AddrFrom4(*(*[4]byte)(r.Data))) case fastdns.TypeAAAA: - slog.Info("dns request AAAA", "name", resp.DecodeName(nil, name), "ttl", ttl, "class", class, "type", typ, "AAAA", netip.AddrFrom16(*(*[16]byte)(data))) + slog.Info("dns request AAAA", "name", resp.DecodeName(nil, r.Name), "ttl", r.TTL, "class", r.Class, "type", r.Type, "AAAA", netip.AddrFrom16(*(*[16]byte)(r.Data))) } - return true - }) + } slog.Info("serve dns answers", "remote_addr", rw.RemoteAddr(), "domain", req.Domain, "remote_addr", h.DNSClient.Addr, "answer_count", resp.Header.ANCount) } diff --git a/go.mod b/go.mod index 84915e4..6057429 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/phuslu/fastdns -go 1.22 \ No newline at end of file +go 1.23 \ No newline at end of file diff --git a/message.go b/message.go index d67d4f2..8270c0e 100644 --- a/message.go +++ b/message.go @@ -215,11 +215,19 @@ func (msg *Message) DecodeName(dst []byte, name []byte) []byte { return dst } +type AnswerRecord struct { + Name []byte + Type Type + Class Class + TTL uint32 + Data []byte +} + // Walk calls f for each item in the msg in the original order of the parsed RR. -func (msg *Message) Walk(f func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool) error { +func (msg *Message) Records(f func(AnswerRecord) bool) { n := msg.Header.ANCount + msg.Header.NSCount if n == 0 { - return ErrInvalidAnswer + return } payload := msg.Raw[16+len(msg.Question.Name):] @@ -238,7 +246,7 @@ func (msg *Message) Walk(f func(name []byte, typ Type, class Class, ttl uint32, } } if name == nil { - return ErrInvalidAnswer + return } _ = payload[9] // hint compiler to remove bounds check typ := Type(payload[0])<<8 | Type(payload[1]) @@ -247,17 +255,15 @@ func (msg *Message) Walk(f func(name []byte, typ Type, class Class, ttl uint32, length := uint16(payload[8])<<8 | uint16(payload[9]) data := payload[10 : 10+length] payload = payload[10+length:] - ok := f(name, typ, class, ttl, data) + ok := f(AnswerRecord{Name: name, Type: typ, Class: class, TTL: ttl, Data: data}) if !ok { break } } - - return nil } // WalkAdditionalRecords calls f for each item in the msg in the original order of the parsed AR. -func (msg *Message) WalkAdditionalRecords(f func(name []byte, typ Type, class Class, ttl uint32, data []byte) bool) error { +func (msg *Message) AdditionalRecords(f func(AnswerRecord) bool) { panic("not implemented") }