From a205fafc4f9cb4f96f5350b4d5c3b5cf54f62cfe Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Tue, 7 Nov 2023 16:09:53 +1030 Subject: [PATCH] Use cryptobyte.String to unpack EDNS0 and SVCB values This unpacking is more strict than previous. It now validates that no trailing data exists. --- edns.go | 152 ++++++++++++++++++++++++++++--------------------- edns_test.go | 13 +++-- msg_helpers.go | 40 +++++++------ svcb.go | 120 +++++++++++++++++++------------------- svcb_test.go | 13 +++-- 5 files changed, 185 insertions(+), 153 deletions(-) diff --git a/edns.go b/edns.go index 1b58e8f0a..eebbc955a 100644 --- a/edns.go +++ b/edns.go @@ -7,6 +7,8 @@ import ( "fmt" "net" "strconv" + + "golang.org/x/crypto/cryptobyte" ) // EDNS0 Option codes. @@ -214,9 +216,8 @@ type EDNS0 interface { Option() uint16 // pack returns the bytes of the option data. pack() ([]byte, error) - // unpack sets the data as found in the buffer. Is also sets - // the length of the slice as the length of the option data. - unpack([]byte) error + // unpack sets the data as found in the buffer. + unpack(*cryptobyte.String) error // String returns the string representation of the option. String() string // copy returns a deep-copy of the option. @@ -248,11 +249,16 @@ func (e *EDNS0_NSID) pack() ([]byte, error) { return h, nil } +func (e *EDNS0_NSID) unpack(b *cryptobyte.String) error { + e.Nsid = hex.EncodeToString(*b) + b.Skip(len(*b)) + return nil +} + // Option implements the EDNS0 interface. -func (e *EDNS0_NSID) Option() uint16 { return EDNS0NSID } // Option returns the option code. -func (e *EDNS0_NSID) unpack(b []byte) error { e.Nsid = hex.EncodeToString(b); return nil } -func (e *EDNS0_NSID) String() string { return e.Nsid } -func (e *EDNS0_NSID) copy() EDNS0 { return &EDNS0_NSID{e.Code, e.Nsid} } +func (e *EDNS0_NSID) Option() uint16 { return EDNS0NSID } // Option returns the option code. +func (e *EDNS0_NSID) String() string { return e.Nsid } +func (e *EDNS0_NSID) copy() EDNS0 { return &EDNS0_NSID{e.Code, e.Nsid} } // EDNS0_SUBNET is the subnet option that is used to give the remote nameserver // an idea of where the client lives. See RFC 7871. It can then give back a different @@ -323,13 +329,12 @@ func (e *EDNS0_SUBNET) pack() ([]byte, error) { return b, nil } -func (e *EDNS0_SUBNET) unpack(b []byte) error { - if len(b) < 4 { - return ErrBuf +func (e *EDNS0_SUBNET) unpack(b *cryptobyte.String) error { + if !b.ReadUint16(&e.Family) || + !b.ReadUint8(&e.SourceNetmask) || + !b.ReadUint8(&e.SourceScope) { + return errUnpackOverflow } - e.Family = binary.BigEndian.Uint16(b) - e.SourceNetmask = b[2] - e.SourceScope = b[3] switch e.Family { case 0: // "dig" sets AddressFamily to 0 if SourceNetmask is also 0 @@ -343,14 +348,14 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error { return errors.New("dns: bad netmask") } addr := make(net.IP, net.IPv4len) - copy(addr, b[4:]) + b.Skip(copy(addr, *b)) e.Address = addr.To16() case 2: if e.SourceNetmask > net.IPv6len*8 || e.SourceScope > net.IPv6len*8 { return errors.New("dns: bad netmask") } addr := make(net.IP, net.IPv6len) - copy(addr, b[4:]) + b.Skip(copy(addr, *b)) e.Address = addr default: return errors.New("dns: bad address family") @@ -411,11 +416,16 @@ func (e *EDNS0_COOKIE) pack() ([]byte, error) { return h, nil } +func (e *EDNS0_COOKIE) unpack(b *cryptobyte.String) error { + e.Cookie = hex.EncodeToString(*b) + b.Skip(len(*b)) + return nil +} + // Option implements the EDNS0 interface. -func (e *EDNS0_COOKIE) Option() uint16 { return EDNS0COOKIE } -func (e *EDNS0_COOKIE) unpack(b []byte) error { e.Cookie = hex.EncodeToString(b); return nil } -func (e *EDNS0_COOKIE) String() string { return e.Cookie } -func (e *EDNS0_COOKIE) copy() EDNS0 { return &EDNS0_COOKIE{e.Code, e.Cookie} } +func (e *EDNS0_COOKIE) Option() uint16 { return EDNS0COOKIE } +func (e *EDNS0_COOKIE) String() string { return e.Cookie } +func (e *EDNS0_COOKIE) copy() EDNS0 { return &EDNS0_COOKIE{e.Code, e.Cookie} } // The EDNS0_UL (Update Lease) (draft RFC) option is used to tell the server to set // an expiration on an update RR. This is helpful for clients that cannot clean @@ -453,16 +463,13 @@ func (e *EDNS0_UL) pack() ([]byte, error) { return b, nil } -func (e *EDNS0_UL) unpack(b []byte) error { - switch len(b) { - case 4: - e.KeyLease = 0 - case 8: - e.KeyLease = binary.BigEndian.Uint32(b[4:]) - default: - return ErrBuf +func (e *EDNS0_UL) unpack(b *cryptobyte.String) error { + if !b.ReadUint32(&e.Lease) { + return errUnpackOverflow + } + if !b.Empty() && !b.ReadUint32(&e.KeyLease) { + return errUnpackOverflow } - e.Lease = binary.BigEndian.Uint32(b) return nil } @@ -490,15 +497,14 @@ func (e *EDNS0_LLQ) pack() ([]byte, error) { return b, nil } -func (e *EDNS0_LLQ) unpack(b []byte) error { - if len(b) < 18 { - return ErrBuf +func (e *EDNS0_LLQ) unpack(b *cryptobyte.String) error { + if !b.ReadUint16(&e.Version) || + !b.ReadUint16(&e.Opcode) || + !b.ReadUint16(&e.Error) || + !b.ReadUint64(&e.Id) || + !b.ReadUint32(&e.LeaseLife) { + return errUnpackOverflow } - e.Version = binary.BigEndian.Uint16(b[0:]) - e.Opcode = binary.BigEndian.Uint16(b[2:]) - e.Error = binary.BigEndian.Uint16(b[4:]) - e.Id = binary.BigEndian.Uint64(b[6:]) - e.LeaseLife = binary.BigEndian.Uint32(b[14:]) return nil } @@ -522,7 +528,12 @@ type EDNS0_DAU struct { // Option implements the EDNS0 interface. func (e *EDNS0_DAU) Option() uint16 { return EDNS0DAU } func (e *EDNS0_DAU) pack() ([]byte, error) { return cloneSlice(e.AlgCode), nil } -func (e *EDNS0_DAU) unpack(b []byte) error { e.AlgCode = cloneSlice(b); return nil } + +func (e *EDNS0_DAU) unpack(b *cryptobyte.String) error { + e.AlgCode = cloneSlice(*b) + b.Skip(len(*b)) + return nil +} func (e *EDNS0_DAU) String() string { s := "" @@ -546,7 +557,12 @@ type EDNS0_DHU struct { // Option implements the EDNS0 interface. func (e *EDNS0_DHU) Option() uint16 { return EDNS0DHU } func (e *EDNS0_DHU) pack() ([]byte, error) { return cloneSlice(e.AlgCode), nil } -func (e *EDNS0_DHU) unpack(b []byte) error { e.AlgCode = cloneSlice(b); return nil } + +func (e *EDNS0_DHU) unpack(b *cryptobyte.String) error { + e.AlgCode = cloneSlice(*b) + b.Skip(len(*b)) + return nil +} func (e *EDNS0_DHU) String() string { s := "" @@ -570,7 +586,12 @@ type EDNS0_N3U struct { // Option implements the EDNS0 interface. func (e *EDNS0_N3U) Option() uint16 { return EDNS0N3U } func (e *EDNS0_N3U) pack() ([]byte, error) { return cloneSlice(e.AlgCode), nil } -func (e *EDNS0_N3U) unpack(b []byte) error { e.AlgCode = cloneSlice(b); return nil } + +func (e *EDNS0_N3U) unpack(b *cryptobyte.String) error { + e.AlgCode = cloneSlice(*b) + b.Skip(len(*b)) + return nil +} func (e *EDNS0_N3U) String() string { // Re-use the hash map @@ -606,17 +627,12 @@ func (e *EDNS0_EXPIRE) pack() ([]byte, error) { return b, nil } -func (e *EDNS0_EXPIRE) unpack(b []byte) error { - if len(b) == 0 { - // zero-length EXPIRE query, see RFC 7314 Section 2 - e.Empty = true - return nil - } - if len(b) < 4 { - return ErrBuf +func (e *EDNS0_EXPIRE) unpack(b *cryptobyte.String) error { + // zero-length EXPIRE query, see RFC 7314 Section 2 + e.Empty = b.Empty() + if !b.Empty() && !b.ReadUint32(&e.Expire) { + return errUnpackOverflow } - e.Expire = binary.BigEndian.Uint32(b) - e.Empty = false return nil } @@ -660,8 +676,9 @@ func (e *EDNS0_LOCAL) pack() ([]byte, error) { return cloneSlice(e.Data), nil } -func (e *EDNS0_LOCAL) unpack(b []byte) error { - e.Data = cloneSlice(b) +func (e *EDNS0_LOCAL) unpack(b *cryptobyte.String) error { + e.Data = cloneSlice(*b) + b.Skip(len(*b)) return nil } @@ -692,13 +709,9 @@ func (e *EDNS0_TCP_KEEPALIVE) pack() ([]byte, error) { return nil, nil } -func (e *EDNS0_TCP_KEEPALIVE) unpack(b []byte) error { - switch len(b) { - case 0: - case 2: - e.Timeout = binary.BigEndian.Uint16(b) - default: - return fmt.Errorf("dns: length mismatch, want 0/2 but got %d", len(b)) +func (e *EDNS0_TCP_KEEPALIVE) unpack(b *cryptobyte.String) error { + if !b.Empty() && !b.ReadUint16(&e.Timeout) { + return errUnpackOverflow } return nil } @@ -725,10 +738,15 @@ type EDNS0_PADDING struct { // Option implements the EDNS0 interface. func (e *EDNS0_PADDING) Option() uint16 { return EDNS0PADDING } func (e *EDNS0_PADDING) pack() ([]byte, error) { return cloneSlice(e.Padding), nil } -func (e *EDNS0_PADDING) unpack(b []byte) error { e.Padding = cloneSlice(b); return nil } func (e *EDNS0_PADDING) String() string { return fmt.Sprintf("%0X", e.Padding) } func (e *EDNS0_PADDING) copy() EDNS0 { return &EDNS0_PADDING{cloneSlice(e.Padding)} } +func (e *EDNS0_PADDING) unpack(b *cryptobyte.String) error { + e.Padding = cloneSlice(*b) + b.Skip(len(*b)) + return nil +} + // Extended DNS Error Codes (RFC 8914). const ( ExtendedErrorCodeOther uint16 = iota @@ -818,12 +836,12 @@ func (e *EDNS0_EDE) pack() ([]byte, error) { return b, nil } -func (e *EDNS0_EDE) unpack(b []byte) error { - if len(b) < 2 { - return ErrBuf +func (e *EDNS0_EDE) unpack(b *cryptobyte.String) error { + if !b.ReadUint16(&e.InfoCode) { + return errUnpackOverflow } - e.InfoCode = binary.BigEndian.Uint16(b[0:]) - e.ExtraText = string(b[2:]) + e.ExtraText = string(*b) + b.Skip(len(*b)) return nil } @@ -838,7 +856,9 @@ func (e *EDNS0_ESU) Option() uint16 { return EDNS0ESU } func (e *EDNS0_ESU) String() string { return e.Uri } func (e *EDNS0_ESU) copy() EDNS0 { return &EDNS0_ESU{e.Code, e.Uri} } func (e *EDNS0_ESU) pack() ([]byte, error) { return []byte(e.Uri), nil } -func (e *EDNS0_ESU) unpack(b []byte) error { - e.Uri = string(b) + +func (e *EDNS0_ESU) unpack(b *cryptobyte.String) error { + e.Uri = string(*b) + b.Skip(len(*b)) return nil } diff --git a/edns_test.go b/edns_test.go index b7c15f7e4..c3b73ad8a 100644 --- a/edns_test.go +++ b/edns_test.go @@ -4,6 +4,8 @@ import ( "bytes" "net" "testing" + + "golang.org/x/crypto/cryptobyte" ) func TestOPTTtl(t *testing.T) { @@ -101,7 +103,7 @@ func TestEDNS0_SUBNETUnpack(t *testing.T) { } var s2 EDNS0_SUBNET - if err := s2.unpack(b); err != nil { + if err := s2.unpack((*cryptobyte.String)(&b)); err != nil { t.Fatalf("failed to unpack: %v", err) } @@ -125,8 +127,8 @@ func TestEDNS0_UL(t *testing.T) { if err != nil { t.Fatalf("failed to pack: %v", err) } - actual := EDNS0_UL{EDNS0UL, ^uint32(0), ^uint32(0)} - if err := actual.unpack(b); err != nil { + actual := EDNS0_UL{EDNS0UL, 0, 0} + if err := actual.unpack((*cryptobyte.String)(&b)); err != nil { t.Fatalf("failed to unpack: %v", err) } if expect != actual { @@ -213,15 +215,16 @@ func TestEDNS0_TCP_KEEPALIVE_unpack(t *testing.T) { }, { name: "invalid", - b: []byte{0, 1, 3}, + b: []byte{1}, expectedErr: true, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { + b := tc.b e := &EDNS0_TCP_KEEPALIVE{} - err := e.unpack(tc.b) + err := e.unpack((*cryptobyte.String)(&b)) if err != nil && !tc.expectedErr { t.Error("failed to unpack, expected no error") } diff --git a/msg_helpers.go b/msg_helpers.go index 8e0762a85..78efd3edc 100644 --- a/msg_helpers.go +++ b/msg_helpers.go @@ -334,20 +334,23 @@ func packStringTxt(s []string, msg []byte, off int) (int, error) { } func unpackDataOpt(msg *cryptobyte.String) ([]EDNS0, error) { - var edns []EDNS0 + var ( + edns []EDNS0 + data cryptobyte.String // Move this out of the loop as it escapes into unpack. + ) for !msg.Empty() { - var ( - code uint16 - optData cryptobyte.String - ) + var code uint16 if !msg.ReadUint16(&code) || - !msg.ReadUint16LengthPrefixed(&optData) { + !msg.ReadUint16LengthPrefixed(&data) { return nil, &Error{err: "overflow unpacking opt"} } opt := makeDataOpt(code) - if err := opt.unpack(optData); err != nil { + if err := opt.unpack(&data); err != nil { return nil, err } + if !data.Empty() { + return nil, &Error{err: "trailing data after opt"} + } edns = append(edns, opt) } return edns, nil @@ -489,29 +492,32 @@ func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) { } func unpackDataSVCB(msg *cryptobyte.String) ([]SVCBKeyValue, error) { - var xs []SVCBKeyValue + var ( + svcb []SVCBKeyValue + data cryptobyte.String // Move this out of the loop as it escapes into unpack. + ) for !msg.Empty() { - var ( - code uint16 - kvData cryptobyte.String - ) + var code uint16 if !msg.ReadUint16(&code) || - !msg.ReadUint16LengthPrefixed(&kvData) { + !msg.ReadUint16LengthPrefixed(&data) { return nil, &Error{err: "overflow unpacking SVCB"} } kv := makeSVCBKeyValue(SVCBKey(code)) if kv == nil { return nil, &Error{err: "bad SVCB key"} } - if err := kv.unpack(kvData); err != nil { + if err := kv.unpack(&data); err != nil { return nil, err } - if len(xs) > 0 && kv.Key() <= xs[len(xs)-1].Key() { + if !data.Empty() { + return nil, &Error{err: "trailing data after SVCB key-value"} + } + if len(svcb) > 0 && kv.Key() <= svcb[len(svcb)-1].Key() { return nil, &Error{err: "SVCB keys not in strictly increasing order"} } - xs = append(xs, kv) + svcb = append(svcb, kv) } - return xs, nil + return svcb, nil } func packDataSVCB(pairs []SVCBKeyValue, msg []byte, off int) (int, error) { diff --git a/svcb.go b/svcb.go index d38aa2f05..38bbb8203 100644 --- a/svcb.go +++ b/svcb.go @@ -9,6 +9,8 @@ import ( "sort" "strconv" "strings" + + "golang.org/x/crypto/cryptobyte" ) // SVCBKey is the type of the keys used in the SVCB RR. @@ -243,13 +245,13 @@ func (rr *HTTPS) parse(c *zlexer, o string) *ParseError { // SVCBKeyValue defines a key=value pair for the SVCB RR type. // An SVCB RR can have multiple SVCBKeyValues appended to it. type SVCBKeyValue interface { - Key() SVCBKey // Key returns the numerical key code. - pack() ([]byte, error) // pack returns the encoded value. - unpack([]byte) error // unpack sets the value. - String() string // String returns the string representation of the value. - parse(string) error // parse sets the value to the given string representation of the value. - copy() SVCBKeyValue // copy returns a deep-copy of the pair. - len() int // len returns the length of value in the wire format. + Key() SVCBKey // Key returns the numerical key code. + pack() ([]byte, error) // pack returns the encoded value. + unpack(*cryptobyte.String) error // unpack sets the value. + String() string // String returns the string representation of the value. + parse(string) error // parse sets the value to the given string representation of the value. + copy() SVCBKeyValue // copy returns a deep-copy of the pair. + len() int // len returns the length of value in the wire format. } // SVCBMandatory pair adds to required keys that must be interpreted for the RR @@ -300,14 +302,15 @@ func (s *SVCBMandatory) pack() ([]byte, error) { return b, nil } -func (s *SVCBMandatory) unpack(b []byte) error { - if len(b)%2 != 0 { - return errors.New("dns: svcbmandatory: value length is not a multiple of 2") - } - codes := make([]SVCBKey, 0, len(b)/2) - for i := 0; i < len(b); i += 2 { +func (s *SVCBMandatory) unpack(b *cryptobyte.String) error { + codes := make([]SVCBKey, 0, len(*b)/2) + for !b.Empty() { // We assume strictly increasing order. - codes = append(codes, SVCBKey(binary.BigEndian.Uint16(b[i:]))) + var code SVCBKey + if !b.ReadUint16((*uint16)(&code)) { + return errUnpackOverflow + } + codes = append(codes, code) } s.Code = codes return nil @@ -410,17 +413,14 @@ func (s *SVCBAlpn) pack() ([]byte, error) { return b, nil } -func (s *SVCBAlpn) unpack(b []byte) error { - // Estimate the size of the smallest alpn as 4 bytes - alpn := make([]string, 0, len(b)/4) - for i := 0; i < len(b); { - length := int(b[i]) - i++ - if i+length > len(b) { - return errors.New("dns: svcbalpn: alpn array overflowing") +func (s *SVCBAlpn) unpack(b *cryptobyte.String) error { + var alpn []string + for !b.Empty() { + var data cryptobyte.String + if !b.ReadUint8LengthPrefixed(&data) { + return errUnpackOverflow } - alpn = append(alpn, string(b[i:i+length])) - i += length + alpn = append(alpn, string(data)) } s.Alpn = alpn return nil @@ -495,18 +495,12 @@ func (s *SVCBAlpn) copy() SVCBKeyValue { // s.Value = append(s.Value, e) type SVCBNoDefaultAlpn struct{} -func (*SVCBNoDefaultAlpn) Key() SVCBKey { return SVCB_NO_DEFAULT_ALPN } -func (*SVCBNoDefaultAlpn) copy() SVCBKeyValue { return &SVCBNoDefaultAlpn{} } -func (*SVCBNoDefaultAlpn) pack() ([]byte, error) { return []byte{}, nil } -func (*SVCBNoDefaultAlpn) String() string { return "" } -func (*SVCBNoDefaultAlpn) len() int { return 0 } - -func (*SVCBNoDefaultAlpn) unpack(b []byte) error { - if len(b) != 0 { - return errors.New("dns: svcbnodefaultalpn: no-default-alpn must have no value") - } - return nil -} +func (*SVCBNoDefaultAlpn) Key() SVCBKey { return SVCB_NO_DEFAULT_ALPN } +func (*SVCBNoDefaultAlpn) copy() SVCBKeyValue { return &SVCBNoDefaultAlpn{} } +func (*SVCBNoDefaultAlpn) pack() ([]byte, error) { return []byte{}, nil } +func (*SVCBNoDefaultAlpn) unpack(*cryptobyte.String) error { return nil } +func (*SVCBNoDefaultAlpn) String() string { return "" } +func (*SVCBNoDefaultAlpn) len() int { return 0 } func (*SVCBNoDefaultAlpn) parse(b string) error { if b != "" { @@ -531,11 +525,10 @@ func (*SVCBPort) len() int { return 2 } func (s *SVCBPort) String() string { return strconv.FormatUint(uint64(s.Port), 10) } func (s *SVCBPort) copy() SVCBKeyValue { return &SVCBPort{s.Port} } -func (s *SVCBPort) unpack(b []byte) error { - if len(b) != 2 { - return errors.New("dns: svcbport: port length is not exactly 2 octets") +func (s *SVCBPort) unpack(b *cryptobyte.String) error { + if !b.ReadUint16(&s.Port) { + return errUnpackOverflow } - s.Port = binary.BigEndian.Uint16(b) return nil } @@ -588,16 +581,17 @@ func (s *SVCBIPv4Hint) pack() ([]byte, error) { return b, nil } -func (s *SVCBIPv4Hint) unpack(b []byte) error { - if len(b) == 0 || len(b)%4 != 0 { +func (s *SVCBIPv4Hint) unpack(b *cryptobyte.String) error { + if b.Empty() || len(*b)%4 != 0 { return errors.New("dns: svcbipv4hint: ipv4 address byte array length is not a multiple of 4") } - b = cloneSlice(b) - x := make([]net.IP, 0, len(b)/4) - for i := 0; i < len(b); i += 4 { - x = append(x, net.IP(b[i:i+4])) + bb := cloneSlice(*b) + b.Skip(len(*b)) + hints := make([]net.IP, 0, len(bb)/4) + for i := 0; i < len(bb); i += 4 { + hints = append(hints, net.IP(bb[i:i+4])) } - s.Hint = x + s.Hint = hints return nil } @@ -667,8 +661,9 @@ func (s *SVCBECHConfig) copy() SVCBKeyValue { return &SVCBECHConfig{cloneSlice(s.ECH)} } -func (s *SVCBECHConfig) unpack(b []byte) error { - s.ECH = cloneSlice(b) +func (s *SVCBECHConfig) unpack(b *cryptobyte.String) error { + s.ECH = cloneSlice(*b) + b.Skip(len(*b)) return nil } @@ -710,20 +705,21 @@ func (s *SVCBIPv6Hint) pack() ([]byte, error) { return b, nil } -func (s *SVCBIPv6Hint) unpack(b []byte) error { - if len(b) == 0 || len(b)%16 != 0 { +func (s *SVCBIPv6Hint) unpack(b *cryptobyte.String) error { + if b.Empty() || len(*b)%16 != 0 { return errors.New("dns: svcbipv6hint: ipv6 address byte array length not a multiple of 16") } - b = cloneSlice(b) - x := make([]net.IP, 0, len(b)/16) - for i := 0; i < len(b); i += 16 { - ip := net.IP(b[i : i+16]) + bb := cloneSlice(*b) + b.Skip(len(*b)) + hints := make([]net.IP, 0, len(bb)/16) + for i := 0; i < len(bb); i += 16 { + ip := net.IP(bb[i : i+16]) if ip.To4() != nil { return errors.New("dns: svcbipv6hint: expected ipv6, got ipv4") } - x = append(x, ip) + hints = append(hints, ip) } - s.Hint = x + s.Hint = hints return nil } @@ -796,8 +792,9 @@ func (s *SVCBDoHPath) String() string { return svcbParamToStr([]byte(s.Te func (s *SVCBDoHPath) len() int { return len(s.Template) } func (s *SVCBDoHPath) pack() ([]byte, error) { return []byte(s.Template), nil } -func (s *SVCBDoHPath) unpack(b []byte) error { - s.Template = string(b) +func (s *SVCBDoHPath) unpack(b *cryptobyte.String) error { + s.Template = string(*b) + b.Skip(len(*b)) return nil } @@ -836,8 +833,9 @@ func (s *SVCBLocal) String() string { return svcbParamToStr(s.Data) } func (s *SVCBLocal) pack() ([]byte, error) { return cloneSlice(s.Data), nil } func (s *SVCBLocal) len() int { return len(s.Data) } -func (s *SVCBLocal) unpack(b []byte) error { - s.Data = cloneSlice(b) +func (s *SVCBLocal) unpack(b *cryptobyte.String) error { + s.Data = cloneSlice(*b) + b.Skip(len(*b)) return nil } diff --git a/svcb_test.go b/svcb_test.go index 63a40102c..11aed8d5d 100644 --- a/svcb_test.go +++ b/svcb_test.go @@ -2,6 +2,8 @@ package dns import ( "testing" + + "golang.org/x/crypto/cryptobyte" ) // This tests everything valid about SVCB but parsing. @@ -50,7 +52,7 @@ func TestSVCB(t *testing.T) { if len(b) != int(kv.len()) { t.Errorf("expected packed svc value %s to be of length %d but got %d", o.key, int(kv.len()), len(b)) } - err = kv.unpack(b) + err = kv.unpack((*cryptobyte.String)(&b)) if err != nil { t.Error("failed to unpack value of svc pair: ", o.key, err) continue @@ -70,10 +72,12 @@ func TestDecodeBadSVCB(t *testing.T) { key: SVCB_ALPN, data: []byte{3, 0, 0}, // There aren't three octets after 3 }, - { + // The caller is responsible for ensuring the buffer is empty after + // unpacking, see unpackDataSVCB. + /*{ key: SVCB_NO_DEFAULT_ALPN, data: []byte{0}, - }, + },*/ { key: SVCB_PORT, data: []byte{}, @@ -88,7 +92,8 @@ func TestDecodeBadSVCB(t *testing.T) { }, } for _, o := range svcbs { - err := makeSVCBKeyValue(SVCBKey(o.key)).unpack(o.data) + data := o.data + err := makeSVCBKeyValue(SVCBKey(o.key)).unpack((*cryptobyte.String)(&data)) if err == nil { t.Error("accepted invalid svc value with key ", SVCBKey(o.key).String()) }