Skip to content

Commit

Permalink
Use cryptobyte.String to unpack EDNS0 and SVCB values
Browse files Browse the repository at this point in the history
This unpacking is more strict than previous. It now validates that no
trailing data exists.
  • Loading branch information
tmthrgd committed Nov 7, 2023
1 parent 00fe85c commit a205faf
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 153 deletions.
152 changes: 86 additions & 66 deletions edns.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"net"
"strconv"

"golang.org/x/crypto/cryptobyte"
)

// EDNS0 Option codes.
Expand Down Expand Up @@ -214,9 +216,8 @@ type EDNS0 interface {
Option() uint16
// pack returns the bytes of the option data.
pack() ([]byte, error)
// unpack sets the data as found in the buffer. Is also sets
// the length of the slice as the length of the option data.
unpack([]byte) error
// unpack sets the data as found in the buffer.
unpack(*cryptobyte.String) error
// String returns the string representation of the option.
String() string
// copy returns a deep-copy of the option.
Expand Down Expand Up @@ -248,11 +249,16 @@ func (e *EDNS0_NSID) pack() ([]byte, error) {
return h, nil
}

func (e *EDNS0_NSID) unpack(b *cryptobyte.String) error {
e.Nsid = hex.EncodeToString(*b)
b.Skip(len(*b))
return nil
}

// Option implements the EDNS0 interface.
func (e *EDNS0_NSID) Option() uint16 { return EDNS0NSID } // Option returns the option code.
func (e *EDNS0_NSID) unpack(b []byte) error { e.Nsid = hex.EncodeToString(b); return nil }
func (e *EDNS0_NSID) String() string { return e.Nsid }
func (e *EDNS0_NSID) copy() EDNS0 { return &EDNS0_NSID{e.Code, e.Nsid} }
func (e *EDNS0_NSID) Option() uint16 { return EDNS0NSID } // Option returns the option code.
func (e *EDNS0_NSID) String() string { return e.Nsid }
func (e *EDNS0_NSID) copy() EDNS0 { return &EDNS0_NSID{e.Code, e.Nsid} }

// EDNS0_SUBNET is the subnet option that is used to give the remote nameserver
// an idea of where the client lives. See RFC 7871. It can then give back a different
Expand Down Expand Up @@ -323,13 +329,12 @@ func (e *EDNS0_SUBNET) pack() ([]byte, error) {
return b, nil
}

func (e *EDNS0_SUBNET) unpack(b []byte) error {
if len(b) < 4 {
return ErrBuf
func (e *EDNS0_SUBNET) unpack(b *cryptobyte.String) error {
if !b.ReadUint16(&e.Family) ||
!b.ReadUint8(&e.SourceNetmask) ||
!b.ReadUint8(&e.SourceScope) {
return errUnpackOverflow
}
e.Family = binary.BigEndian.Uint16(b)
e.SourceNetmask = b[2]
e.SourceScope = b[3]
switch e.Family {
case 0:
// "dig" sets AddressFamily to 0 if SourceNetmask is also 0
Expand All @@ -343,14 +348,14 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error {
return errors.New("dns: bad netmask")
}
addr := make(net.IP, net.IPv4len)
copy(addr, b[4:])
b.Skip(copy(addr, *b))
e.Address = addr.To16()
case 2:
if e.SourceNetmask > net.IPv6len*8 || e.SourceScope > net.IPv6len*8 {
return errors.New("dns: bad netmask")
}
addr := make(net.IP, net.IPv6len)
copy(addr, b[4:])
b.Skip(copy(addr, *b))
e.Address = addr
default:
return errors.New("dns: bad address family")
Expand Down Expand Up @@ -411,11 +416,16 @@ func (e *EDNS0_COOKIE) pack() ([]byte, error) {
return h, nil
}

func (e *EDNS0_COOKIE) unpack(b *cryptobyte.String) error {
e.Cookie = hex.EncodeToString(*b)
b.Skip(len(*b))
return nil
}

// Option implements the EDNS0 interface.
func (e *EDNS0_COOKIE) Option() uint16 { return EDNS0COOKIE }
func (e *EDNS0_COOKIE) unpack(b []byte) error { e.Cookie = hex.EncodeToString(b); return nil }
func (e *EDNS0_COOKIE) String() string { return e.Cookie }
func (e *EDNS0_COOKIE) copy() EDNS0 { return &EDNS0_COOKIE{e.Code, e.Cookie} }
func (e *EDNS0_COOKIE) Option() uint16 { return EDNS0COOKIE }
func (e *EDNS0_COOKIE) String() string { return e.Cookie }
func (e *EDNS0_COOKIE) copy() EDNS0 { return &EDNS0_COOKIE{e.Code, e.Cookie} }

// The EDNS0_UL (Update Lease) (draft RFC) option is used to tell the server to set
// an expiration on an update RR. This is helpful for clients that cannot clean
Expand Down Expand Up @@ -453,16 +463,13 @@ func (e *EDNS0_UL) pack() ([]byte, error) {
return b, nil
}

func (e *EDNS0_UL) unpack(b []byte) error {
switch len(b) {
case 4:
e.KeyLease = 0
case 8:
e.KeyLease = binary.BigEndian.Uint32(b[4:])
default:
return ErrBuf
func (e *EDNS0_UL) unpack(b *cryptobyte.String) error {
if !b.ReadUint32(&e.Lease) {
return errUnpackOverflow
}
if !b.Empty() && !b.ReadUint32(&e.KeyLease) {
return errUnpackOverflow
}
e.Lease = binary.BigEndian.Uint32(b)
return nil
}

Expand Down Expand Up @@ -490,15 +497,14 @@ func (e *EDNS0_LLQ) pack() ([]byte, error) {
return b, nil
}

func (e *EDNS0_LLQ) unpack(b []byte) error {
if len(b) < 18 {
return ErrBuf
func (e *EDNS0_LLQ) unpack(b *cryptobyte.String) error {
if !b.ReadUint16(&e.Version) ||
!b.ReadUint16(&e.Opcode) ||
!b.ReadUint16(&e.Error) ||
!b.ReadUint64(&e.Id) ||
!b.ReadUint32(&e.LeaseLife) {
return errUnpackOverflow
}
e.Version = binary.BigEndian.Uint16(b[0:])
e.Opcode = binary.BigEndian.Uint16(b[2:])
e.Error = binary.BigEndian.Uint16(b[4:])
e.Id = binary.BigEndian.Uint64(b[6:])
e.LeaseLife = binary.BigEndian.Uint32(b[14:])
return nil
}

Expand All @@ -522,7 +528,12 @@ type EDNS0_DAU struct {
// Option implements the EDNS0 interface.
func (e *EDNS0_DAU) Option() uint16 { return EDNS0DAU }
func (e *EDNS0_DAU) pack() ([]byte, error) { return cloneSlice(e.AlgCode), nil }
func (e *EDNS0_DAU) unpack(b []byte) error { e.AlgCode = cloneSlice(b); return nil }

func (e *EDNS0_DAU) unpack(b *cryptobyte.String) error {
e.AlgCode = cloneSlice(*b)
b.Skip(len(*b))
return nil
}

func (e *EDNS0_DAU) String() string {
s := ""
Expand All @@ -546,7 +557,12 @@ type EDNS0_DHU struct {
// Option implements the EDNS0 interface.
func (e *EDNS0_DHU) Option() uint16 { return EDNS0DHU }
func (e *EDNS0_DHU) pack() ([]byte, error) { return cloneSlice(e.AlgCode), nil }
func (e *EDNS0_DHU) unpack(b []byte) error { e.AlgCode = cloneSlice(b); return nil }

func (e *EDNS0_DHU) unpack(b *cryptobyte.String) error {
e.AlgCode = cloneSlice(*b)
b.Skip(len(*b))
return nil
}

func (e *EDNS0_DHU) String() string {
s := ""
Expand All @@ -570,7 +586,12 @@ type EDNS0_N3U struct {
// Option implements the EDNS0 interface.
func (e *EDNS0_N3U) Option() uint16 { return EDNS0N3U }
func (e *EDNS0_N3U) pack() ([]byte, error) { return cloneSlice(e.AlgCode), nil }
func (e *EDNS0_N3U) unpack(b []byte) error { e.AlgCode = cloneSlice(b); return nil }

func (e *EDNS0_N3U) unpack(b *cryptobyte.String) error {
e.AlgCode = cloneSlice(*b)
b.Skip(len(*b))
return nil
}

func (e *EDNS0_N3U) String() string {
// Re-use the hash map
Expand Down Expand Up @@ -606,17 +627,12 @@ func (e *EDNS0_EXPIRE) pack() ([]byte, error) {
return b, nil
}

func (e *EDNS0_EXPIRE) unpack(b []byte) error {
if len(b) == 0 {
// zero-length EXPIRE query, see RFC 7314 Section 2
e.Empty = true
return nil
}
if len(b) < 4 {
return ErrBuf
func (e *EDNS0_EXPIRE) unpack(b *cryptobyte.String) error {
// zero-length EXPIRE query, see RFC 7314 Section 2
e.Empty = b.Empty()
if !b.Empty() && !b.ReadUint32(&e.Expire) {
return errUnpackOverflow
}
e.Expire = binary.BigEndian.Uint32(b)
e.Empty = false
return nil
}

Expand Down Expand Up @@ -660,8 +676,9 @@ func (e *EDNS0_LOCAL) pack() ([]byte, error) {
return cloneSlice(e.Data), nil
}

func (e *EDNS0_LOCAL) unpack(b []byte) error {
e.Data = cloneSlice(b)
func (e *EDNS0_LOCAL) unpack(b *cryptobyte.String) error {
e.Data = cloneSlice(*b)
b.Skip(len(*b))
return nil
}

Expand Down Expand Up @@ -692,13 +709,9 @@ func (e *EDNS0_TCP_KEEPALIVE) pack() ([]byte, error) {
return nil, nil
}

func (e *EDNS0_TCP_KEEPALIVE) unpack(b []byte) error {
switch len(b) {
case 0:
case 2:
e.Timeout = binary.BigEndian.Uint16(b)
default:
return fmt.Errorf("dns: length mismatch, want 0/2 but got %d", len(b))
func (e *EDNS0_TCP_KEEPALIVE) unpack(b *cryptobyte.String) error {
if !b.Empty() && !b.ReadUint16(&e.Timeout) {
return errUnpackOverflow
}
return nil
}
Expand All @@ -725,10 +738,15 @@ type EDNS0_PADDING struct {
// Option implements the EDNS0 interface.
func (e *EDNS0_PADDING) Option() uint16 { return EDNS0PADDING }
func (e *EDNS0_PADDING) pack() ([]byte, error) { return cloneSlice(e.Padding), nil }
func (e *EDNS0_PADDING) unpack(b []byte) error { e.Padding = cloneSlice(b); return nil }
func (e *EDNS0_PADDING) String() string { return fmt.Sprintf("%0X", e.Padding) }
func (e *EDNS0_PADDING) copy() EDNS0 { return &EDNS0_PADDING{cloneSlice(e.Padding)} }

func (e *EDNS0_PADDING) unpack(b *cryptobyte.String) error {
e.Padding = cloneSlice(*b)
b.Skip(len(*b))
return nil
}

// Extended DNS Error Codes (RFC 8914).
const (
ExtendedErrorCodeOther uint16 = iota
Expand Down Expand Up @@ -818,12 +836,12 @@ func (e *EDNS0_EDE) pack() ([]byte, error) {
return b, nil
}

func (e *EDNS0_EDE) unpack(b []byte) error {
if len(b) < 2 {
return ErrBuf
func (e *EDNS0_EDE) unpack(b *cryptobyte.String) error {
if !b.ReadUint16(&e.InfoCode) {
return errUnpackOverflow
}
e.InfoCode = binary.BigEndian.Uint16(b[0:])
e.ExtraText = string(b[2:])
e.ExtraText = string(*b)
b.Skip(len(*b))
return nil
}

Expand All @@ -838,7 +856,9 @@ func (e *EDNS0_ESU) Option() uint16 { return EDNS0ESU }
func (e *EDNS0_ESU) String() string { return e.Uri }
func (e *EDNS0_ESU) copy() EDNS0 { return &EDNS0_ESU{e.Code, e.Uri} }
func (e *EDNS0_ESU) pack() ([]byte, error) { return []byte(e.Uri), nil }
func (e *EDNS0_ESU) unpack(b []byte) error {
e.Uri = string(b)

func (e *EDNS0_ESU) unpack(b *cryptobyte.String) error {
e.Uri = string(*b)
b.Skip(len(*b))
return nil
}
13 changes: 8 additions & 5 deletions edns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"net"
"testing"

"golang.org/x/crypto/cryptobyte"
)

func TestOPTTtl(t *testing.T) {
Expand Down Expand Up @@ -101,7 +103,7 @@ func TestEDNS0_SUBNETUnpack(t *testing.T) {
}

var s2 EDNS0_SUBNET
if err := s2.unpack(b); err != nil {
if err := s2.unpack((*cryptobyte.String)(&b)); err != nil {
t.Fatalf("failed to unpack: %v", err)
}

Expand All @@ -125,8 +127,8 @@ func TestEDNS0_UL(t *testing.T) {
if err != nil {
t.Fatalf("failed to pack: %v", err)
}
actual := EDNS0_UL{EDNS0UL, ^uint32(0), ^uint32(0)}
if err := actual.unpack(b); err != nil {
actual := EDNS0_UL{EDNS0UL, 0, 0}
if err := actual.unpack((*cryptobyte.String)(&b)); err != nil {
t.Fatalf("failed to unpack: %v", err)
}
if expect != actual {
Expand Down Expand Up @@ -213,15 +215,16 @@ func TestEDNS0_TCP_KEEPALIVE_unpack(t *testing.T) {
},
{
name: "invalid",
b: []byte{0, 1, 3},
b: []byte{1},
expectedErr: true,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
b := tc.b
e := &EDNS0_TCP_KEEPALIVE{}
err := e.unpack(tc.b)
err := e.unpack((*cryptobyte.String)(&b))
if err != nil && !tc.expectedErr {
t.Error("failed to unpack, expected no error")
}
Expand Down
Loading

0 comments on commit a205faf

Please sign in to comment.