Skip to content

Commit

Permalink
implement marshalling of capsules
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Oct 5, 2024
1 parent e3df00c commit 730cc66
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 0 deletions.
85 changes: 85 additions & 0 deletions capsule.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ type AssignedAddress struct {
IPPrefix netip.Prefix
}

func (a AssignedAddress) len() int {
return quicvarint.Len(a.RequestID) + 1 + a.IPPrefix.Addr().BitLen()/8 + 1
}

// addressRequestCapsule represents an ADDRESS_REQUEST capsule
type addressRequestCapsule struct {
RequestedAddresses []RequestedAddress
Expand All @@ -39,6 +43,10 @@ type RequestedAddress struct {
IPPrefix netip.Prefix
}

func (r RequestedAddress) len() int {
return quicvarint.Len(r.RequestID) + 1 + r.IPPrefix.Addr().BitLen()/8 + 1
}

func parseAddressAssignCapsule(r io.Reader) (*addressAssignCapsule, error) {
var assignedAddresses []AssignedAddress
for {
Expand All @@ -54,6 +62,31 @@ func parseAddressAssignCapsule(r io.Reader) (*addressAssignCapsule, error) {
return &addressAssignCapsule{AssignedAddresses: assignedAddresses}, nil
}

func (c *addressAssignCapsule) marshal(w io.Writer) error {
totalLen := 0
for _, addr := range c.AssignedAddresses {
totalLen += addr.len()
}

buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeAddressAssign))+quicvarint.Len(uint64(totalLen))+totalLen)
buf = quicvarint.Append(buf, uint64(capsuleTypeAddressAssign))
buf = quicvarint.Append(buf, uint64(totalLen))

for _, addr := range c.AssignedAddresses {
buf = quicvarint.Append(buf, addr.RequestID)
if addr.IPPrefix.Addr().Is4() {
buf = append(buf, 4)
} else {
buf = append(buf, 6)
}
buf = append(buf, addr.IPPrefix.Addr().AsSlice()...)
buf = append(buf, byte(addr.IPPrefix.Bits()))
}

_, err := w.Write(buf)
return err
}

func parseAddressRequestCapsule(r io.Reader) (*addressRequestCapsule, error) {
var requestedAddresses []RequestedAddress
for {
Expand All @@ -69,6 +102,31 @@ func parseAddressRequestCapsule(r io.Reader) (*addressRequestCapsule, error) {
return &addressRequestCapsule{RequestedAddresses: requestedAddresses}, nil
}

func (c *addressRequestCapsule) marshal(w io.Writer) error {
var totalLen int
for _, addr := range c.RequestedAddresses {
totalLen += addr.len()
}

buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeAddressRequest))+quicvarint.Len(uint64(totalLen))+totalLen)
buf = quicvarint.Append(buf, uint64(capsuleTypeAddressRequest))
buf = quicvarint.Append(buf, uint64(totalLen))

for _, addr := range c.RequestedAddresses {
buf = quicvarint.Append(buf, addr.RequestID)
if addr.IPPrefix.Addr().Is4() {
buf = append(buf, 4)
} else {
buf = append(buf, 6)
}
buf = append(buf, addr.IPPrefix.Addr().AsSlice()...)
buf = append(buf, byte(addr.IPPrefix.Bits()))
}

_, err := w.Write(buf)
return err
}

func parseAddress(r io.Reader) (requestID uint64, prefix netip.Prefix, _ error) {
vr := quicvarint.NewReader(r)
requestID, err := quicvarint.Read(vr)
Expand Down Expand Up @@ -122,6 +180,8 @@ type IPAddressRange struct {
IPProtocol uint8
}

func (r IPAddressRange) len() int { return 1 + r.StartIP.BitLen()/8 + r.EndIP.BitLen()/8 + 1 }

func parseRouteAdvertisementCapsule(r io.Reader) (*routeAdvertisementCapsule, error) {
var ranges []IPAddressRange
for {
Expand All @@ -137,6 +197,31 @@ func parseRouteAdvertisementCapsule(r io.Reader) (*routeAdvertisementCapsule, er
return &routeAdvertisementCapsule{IPAddressRanges: ranges}, nil
}

func (c *routeAdvertisementCapsule) marshal(w io.Writer) error {
var totalLen int
for _, ipRange := range c.IPAddressRanges {
totalLen += ipRange.len()
}

buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeRouteAdvertisement))+quicvarint.Len(uint64(totalLen))+totalLen)
buf = quicvarint.Append(buf, uint64(capsuleTypeRouteAdvertisement))
buf = quicvarint.Append(buf, uint64(totalLen))

for _, ipRange := range c.IPAddressRanges {
if ipRange.StartIP.Is4() {
buf = append(buf, 4)
} else {
buf = append(buf, 6)
}
buf = append(buf, ipRange.StartIP.AsSlice()...)
buf = append(buf, ipRange.EndIP.AsSlice()...)
buf = append(buf, ipRange.IPProtocol)
}

_, err := w.Write(buf)
return err
}

func parseIPAddressRange(r io.Reader) (IPAddressRange, error) {
var ipVersion uint8
if err := binary.Read(r, binary.LittleEndian, &ipVersion); err != nil {
Expand Down
54 changes: 54 additions & 0 deletions capsule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ func TestParseAddressAssignCapsule(t *testing.T) {
require.Zero(t, r.Len())
}

func TestWriteAddressAssignCapsule(t *testing.T) {
c := &addressAssignCapsule{
AssignedAddresses: []AssignedAddress{
{RequestID: 1337, IPPrefix: netip.MustParsePrefix("1.2.3.0/24")},
{RequestID: 1338, IPPrefix: netip.MustParsePrefix("2001:db8::1/128")},
},
}
buf := &bytes.Buffer{}
require.NoError(t, c.marshal(buf))
typ, cr, err := http3.ParseCapsule(buf)
require.NoError(t, err)
require.Equal(t, capsuleTypeAddressAssign, typ)
parsed, err := parseAddressAssignCapsule(cr)
require.NoError(t, err)
require.Equal(t, c, parsed)
require.Zero(t, buf.Len())
}

func TestParseAddressAssignCapsuleInvalid(t *testing.T) {
testParseAddressCapsuleInvalid(t, capsuleTypeAddressAssign, func(r io.Reader) error {
_, err := parseAddressAssignCapsule(quicvarint.NewReader(r))
Expand Down Expand Up @@ -156,6 +174,24 @@ func TestParseAddressRequestCapsule(t *testing.T) {
require.Zero(t, r.Len())
}

func TestWriteAddressRequestCapsule(t *testing.T) {
c := &addressRequestCapsule{
RequestedAddresses: []RequestedAddress{
{RequestID: 1337, IPPrefix: netip.MustParsePrefix("1.2.3.0/24")},
{RequestID: 1338, IPPrefix: netip.MustParsePrefix("2001:db8::1/128")},
},
}
buf := &bytes.Buffer{}
require.NoError(t, c.marshal(buf))
typ, cr, err := http3.ParseCapsule(buf)
require.NoError(t, err)
require.Equal(t, capsuleTypeAddressRequest, typ)
parsed, err := parseAddressRequestCapsule(cr)
require.NoError(t, err)
require.Equal(t, c, parsed)
require.Zero(t, buf.Len())
}

func TestParseAddressRequestCapsuleInvalid(t *testing.T) {
testParseAddressCapsuleInvalid(t, capsuleTypeAddressRequest, func(r io.Reader) error {
_, err := parseAddressRequestCapsule(quicvarint.NewReader(r))
Expand Down Expand Up @@ -194,6 +230,24 @@ func TestParseRouteAdvertisementCapsule(t *testing.T) {
require.Zero(t, r.Len())
}

func TestWriteRouteAdvertisementCapsule(t *testing.T) {
c := &routeAdvertisementCapsule{
IPAddressRanges: []IPAddressRange{
{StartIP: netip.MustParseAddr("1.1.1.1"), EndIP: netip.MustParseAddr("1.2.3.4"), IPProtocol: 13},
{StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 37},
},
}
buf := &bytes.Buffer{}
require.NoError(t, c.marshal(buf))
typ, cr, err := http3.ParseCapsule(buf)
require.NoError(t, err)
require.Equal(t, capsuleTypeRouteAdvertisement, typ)
parsed, err := parseRouteAdvertisementCapsule(cr)
require.NoError(t, err)
require.Equal(t, c, parsed)
require.Zero(t, buf.Len())
}

func TestParseRouteAdvertisementCapsuleInvalid(t *testing.T) {
t.Run("invalid IP version", func(t *testing.T) {
iprange1 := []byte{5} // IPv5
Expand Down

0 comments on commit 730cc66

Please sign in to comment.