diff --git a/capsule.go b/capsule.go index 93472f2..2860999 100644 --- a/capsule.go +++ b/capsule.go @@ -28,6 +28,10 @@ type AssignedAddress struct { IPPrefix netip.Prefix } +func (a AssignedAddress) len() int { + return quicvarint.Len(a.RequestID) + 1 + a.IPPrefix.Addr().BitLen()/8 + 1 +} + // addressRequestCapsule represents an ADDRESS_REQUEST capsule type addressRequestCapsule struct { RequestedAddresses []RequestedAddress @@ -39,6 +43,10 @@ type RequestedAddress struct { IPPrefix netip.Prefix } +func (r RequestedAddress) len() int { + return quicvarint.Len(r.RequestID) + 1 + r.IPPrefix.Addr().BitLen()/8 + 1 +} + func parseAddressAssignCapsule(r io.Reader) (*addressAssignCapsule, error) { var assignedAddresses []AssignedAddress for { @@ -54,6 +62,31 @@ func parseAddressAssignCapsule(r io.Reader) (*addressAssignCapsule, error) { return &addressAssignCapsule{AssignedAddresses: assignedAddresses}, nil } +func (c *addressAssignCapsule) marshal(w io.Writer) error { + totalLen := 0 + for _, addr := range c.AssignedAddresses { + totalLen += addr.len() + } + + buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeAddressAssign))+quicvarint.Len(uint64(totalLen))+totalLen) + buf = quicvarint.Append(buf, uint64(capsuleTypeAddressAssign)) + buf = quicvarint.Append(buf, uint64(totalLen)) + + for _, addr := range c.AssignedAddresses { + buf = quicvarint.Append(buf, addr.RequestID) + if addr.IPPrefix.Addr().Is4() { + buf = append(buf, 4) + } else { + buf = append(buf, 6) + } + buf = append(buf, addr.IPPrefix.Addr().AsSlice()...) + buf = append(buf, byte(addr.IPPrefix.Bits())) + } + + _, err := w.Write(buf) + return err +} + func parseAddressRequestCapsule(r io.Reader) (*addressRequestCapsule, error) { var requestedAddresses []RequestedAddress for { @@ -69,6 +102,31 @@ func parseAddressRequestCapsule(r io.Reader) (*addressRequestCapsule, error) { return &addressRequestCapsule{RequestedAddresses: requestedAddresses}, nil } +func (c *addressRequestCapsule) marshal(w io.Writer) error { + var totalLen int + for _, addr := range c.RequestedAddresses { + totalLen += addr.len() + } + + buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeAddressRequest))+quicvarint.Len(uint64(totalLen))+totalLen) + buf = quicvarint.Append(buf, uint64(capsuleTypeAddressRequest)) + buf = quicvarint.Append(buf, uint64(totalLen)) + + for _, addr := range c.RequestedAddresses { + buf = quicvarint.Append(buf, addr.RequestID) + if addr.IPPrefix.Addr().Is4() { + buf = append(buf, 4) + } else { + buf = append(buf, 6) + } + buf = append(buf, addr.IPPrefix.Addr().AsSlice()...) + buf = append(buf, byte(addr.IPPrefix.Bits())) + } + + _, err := w.Write(buf) + return err +} + func parseAddress(r io.Reader) (requestID uint64, prefix netip.Prefix, _ error) { vr := quicvarint.NewReader(r) requestID, err := quicvarint.Read(vr) @@ -122,6 +180,8 @@ type IPAddressRange struct { IPProtocol uint8 } +func (r IPAddressRange) len() int { return 1 + r.StartIP.BitLen()/8 + r.EndIP.BitLen()/8 + 1 } + func parseRouteAdvertisementCapsule(r io.Reader) (*routeAdvertisementCapsule, error) { var ranges []IPAddressRange for { @@ -137,6 +197,31 @@ func parseRouteAdvertisementCapsule(r io.Reader) (*routeAdvertisementCapsule, er return &routeAdvertisementCapsule{IPAddressRanges: ranges}, nil } +func (c *routeAdvertisementCapsule) marshal(w io.Writer) error { + var totalLen int + for _, ipRange := range c.IPAddressRanges { + totalLen += ipRange.len() + } + + buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeRouteAdvertisement))+quicvarint.Len(uint64(totalLen))+totalLen) + buf = quicvarint.Append(buf, uint64(capsuleTypeRouteAdvertisement)) + buf = quicvarint.Append(buf, uint64(totalLen)) + + for _, ipRange := range c.IPAddressRanges { + if ipRange.StartIP.Is4() { + buf = append(buf, 4) + } else { + buf = append(buf, 6) + } + buf = append(buf, ipRange.StartIP.AsSlice()...) + buf = append(buf, ipRange.EndIP.AsSlice()...) + buf = append(buf, ipRange.IPProtocol) + } + + _, err := w.Write(buf) + return err +} + func parseIPAddressRange(r io.Reader) (IPAddressRange, error) { var ipVersion uint8 if err := binary.Read(r, binary.LittleEndian, &ipVersion); err != nil { diff --git a/capsule_test.go b/capsule_test.go index 93b4763..95a2ecd 100644 --- a/capsule_test.go +++ b/capsule_test.go @@ -43,6 +43,24 @@ func TestParseAddressAssignCapsule(t *testing.T) { require.Zero(t, r.Len()) } +func TestWriteAddressAssignCapsule(t *testing.T) { + c := &addressAssignCapsule{ + AssignedAddresses: []AssignedAddress{ + {RequestID: 1337, IPPrefix: netip.MustParsePrefix("1.2.3.0/24")}, + {RequestID: 1338, IPPrefix: netip.MustParsePrefix("2001:db8::1/128")}, + }, + } + buf := &bytes.Buffer{} + require.NoError(t, c.marshal(buf)) + typ, cr, err := http3.ParseCapsule(buf) + require.NoError(t, err) + require.Equal(t, capsuleTypeAddressAssign, typ) + parsed, err := parseAddressAssignCapsule(cr) + require.NoError(t, err) + require.Equal(t, c, parsed) + require.Zero(t, buf.Len()) +} + func TestParseAddressAssignCapsuleInvalid(t *testing.T) { testParseAddressCapsuleInvalid(t, capsuleTypeAddressAssign, func(r io.Reader) error { _, err := parseAddressAssignCapsule(quicvarint.NewReader(r)) @@ -156,6 +174,24 @@ func TestParseAddressRequestCapsule(t *testing.T) { require.Zero(t, r.Len()) } +func TestWriteAddressRequestCapsule(t *testing.T) { + c := &addressRequestCapsule{ + RequestedAddresses: []RequestedAddress{ + {RequestID: 1337, IPPrefix: netip.MustParsePrefix("1.2.3.0/24")}, + {RequestID: 1338, IPPrefix: netip.MustParsePrefix("2001:db8::1/128")}, + }, + } + buf := &bytes.Buffer{} + require.NoError(t, c.marshal(buf)) + typ, cr, err := http3.ParseCapsule(buf) + require.NoError(t, err) + require.Equal(t, capsuleTypeAddressRequest, typ) + parsed, err := parseAddressRequestCapsule(cr) + require.NoError(t, err) + require.Equal(t, c, parsed) + require.Zero(t, buf.Len()) +} + func TestParseAddressRequestCapsuleInvalid(t *testing.T) { testParseAddressCapsuleInvalid(t, capsuleTypeAddressRequest, func(r io.Reader) error { _, err := parseAddressRequestCapsule(quicvarint.NewReader(r)) @@ -194,6 +230,24 @@ func TestParseRouteAdvertisementCapsule(t *testing.T) { require.Zero(t, r.Len()) } +func TestWriteRouteAdvertisementCapsule(t *testing.T) { + c := &routeAdvertisementCapsule{ + IPAddressRanges: []IPAddressRange{ + {StartIP: netip.MustParseAddr("1.1.1.1"), EndIP: netip.MustParseAddr("1.2.3.4"), IPProtocol: 13}, + {StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 37}, + }, + } + buf := &bytes.Buffer{} + require.NoError(t, c.marshal(buf)) + typ, cr, err := http3.ParseCapsule(buf) + require.NoError(t, err) + require.Equal(t, capsuleTypeRouteAdvertisement, typ) + parsed, err := parseRouteAdvertisementCapsule(cr) + require.NoError(t, err) + require.Equal(t, c, parsed) + require.Zero(t, buf.Len()) +} + func TestParseRouteAdvertisementCapsuleInvalid(t *testing.T) { t.Run("invalid IP version", func(t *testing.T) { iprange1 := []byte{5} // IPv5