diff --git a/ipv6.go b/ipv6.go index e7d0383..7e16994 100644 --- a/ipv6.go +++ b/ipv6.go @@ -4,8 +4,11 @@ import ( "fmt" "math/big" "net" + "sort" ) +const widthUInt128 = 128 + // ipv6ToUInt128 converts an IPv6 address to an unsigned 128-bit integer. func ipv6ToUInt128(ip net.IP) *big.Int { return big.NewInt(0).SetBytes(ip) @@ -21,18 +24,22 @@ func copyUInt128(x *big.Int) *big.Int { return big.NewInt(0).Set(x) } +// hostmask6 returns the hostmask for the specified prefix. +func hostmask6(prefix uint) *big.Int { + z := big.NewInt(0) + + z.Lsh(big.NewInt(1), widthUInt128-prefix) + z.Sub(z, big.NewInt(1)) + + return z +} + // broadcast6 returns the broadcast address for the given address and prefix. func broadcast6(addr *big.Int, prefix uint) *big.Int { - z := copyUInt128(addr) + z := big.NewInt(0) - if prefix == 0 { - z, _ = z.SetString("340282366920938463463374607431768211455", 10) - return z - } + z.Or(addr, hostmask6(prefix)) - for i := int(prefix); i < 8*net.IPv6len; i++ { - z = z.SetBit(z, i, 1) - } return z } @@ -52,14 +59,14 @@ func network6(addr *big.Int, prefix uint) *big.Int { // splitRange6 recursively computes the CIDR blocks to cover the range lo to hi. func splitRange6(addr *big.Int, prefix uint, lo, hi *big.Int, cidrs *[]*net.IPNet) error { - if prefix > 128 { + if prefix > widthUInt128 { return fmt.Errorf("Invalid mask size: %d", prefix) } bc := broadcast6(addr, prefix) - fmt.Printf("%v/%v/%v/%v/%v\n", addr, prefix, lo, hi, bc) + fmt.Printf("%v/%v, %v-%v, %v\n", uint128ToIPV6(addr), prefix, uint128ToIPV6(lo), uint128ToIPV6(hi), uint128ToIPV6(bc)) if (lo.Cmp(addr) < 0) || (hi.Cmp(bc) > 0) { - return fmt.Errorf("%v, %v out of range for network %v/%d, broadcast %v", lo, hi, addr, prefix, bc) + return fmt.Errorf("%v, %v out of range for network %v/%d, broadcast %v", uint128ToIPV6(lo), uint128ToIPV6(hi), uint128ToIPV6(addr), prefix, uint128ToIPV6(bc)) } if (lo.Cmp(addr) == 0) && (hi.Cmp(bc) == 0) { @@ -71,7 +78,7 @@ func splitRange6(addr *big.Int, prefix uint, lo, hi *big.Int, cidrs *[]*net.IPNe prefix++ lowerHalf := copyUInt128(addr) upperHalf := copyUInt128(addr) - upperHalf = upperHalf.SetBit(upperHalf, int(prefix), 1) + upperHalf.SetBit(upperHalf, int(widthUInt128 - prefix), 1) if hi.Cmp(upperHalf) < 0 { return splitRange6(lowerHalf, prefix, lo, hi, cidrs) } else if lo.Cmp(upperHalf) >= 0 { @@ -84,3 +91,85 @@ func splitRange6(addr *big.Int, prefix uint, lo, hi *big.Int, cidrs *[]*net.IPNe return splitRange6(upperHalf, prefix, upperHalf, hi, cidrs) } } + +// IPv6 CIDR block. + +type cidrBlock6 struct { + first *big.Int + last *big.Int +} + +type cidrBlock6s []*cidrBlock6 + +// newBlock6 returns a new IPv6 CIDR block. +func newBlock6(ip net.IP, mask net.IPMask) *cidrBlock6 { + var block cidrBlock6 + + block.first = ipv6ToUInt128(ip) + prefix, _ := mask.Size() + block.last = broadcast6(block.first, uint(prefix)) + + return &block +} + +// Sort interface. + +func (c cidrBlock6s) Len() int { + return len(c) +} + +func (c cidrBlock6s) Less(i, j int) bool { + lhs := c[i] + rhs := c[j] + + // By last IP in the range. + if lhs.last.Cmp(rhs.last) < 0 { + return true + } else if lhs.last.Cmp(rhs.last) > 0 { + return false + } + + // Then by first IP in the range. + if lhs.first.Cmp(rhs.first) < 0 { + return true + } else if lhs.first.Cmp(rhs.first) > 0 { + return false + } + + return false +} + +func (c cidrBlock6s) Swap(i, j int) { + c[i], c[j] = c[j], c[i] +} + +// merge6 accepts a list of IPv6 networks and merges them into the smallest possible list of IPNets. +// It merges adjacent subnets where possible, those contained within others and removes any duplicates. +func merge6(blocks cidrBlock6s) ([]*net.IPNet, error) { + sort.Sort(blocks) + + // Coalesce overlapping blocks. + for i := len(blocks) - 1; i > 0; i-- { + cmp := blocks[i-1].last + cmp.Add(cmp, big.NewInt(1)) + if blocks[i].first.Cmp(cmp) <= 0 { + blocks[i-1].last = blocks[i].last + if blocks[i].first.Cmp(blocks[i-1].first) < 0 { + blocks[i-1].first = blocks[i].first + } + blocks[i] = nil + } + } + + var merged []*net.IPNet + for _, block := range blocks { + if block == nil { + continue + } + + if err := splitRange6(big.NewInt(0), 0, block.first, block.last, &merged); err != nil { + return nil, err + } + } + return merged, nil +} diff --git a/merge.go b/merge.go index 7668285..31adaaf 100644 --- a/merge.go +++ b/merge.go @@ -4,7 +4,6 @@ package cidrman import ( - "errors" "net" ) @@ -32,20 +31,28 @@ func MergeIPNets(nets []*net.IPNet) ([]*net.IPNet, error) { // Split into IPv4 and IPv6 lists. // Merge the list separately and then combine. var block4s cidrBlock4s + var block6s cidrBlock6s for _, net := range nets { ip4 := net.IP.To4() if ip4 != nil { block4s = append(block4s, newBlock4(ip4, net.Mask)) } else { - return nil, errors.New("Not implemented") + ip6 := net.IP.To16() + block6s = append(block6s, newBlock6(ip6, net.Mask)) } } - merged, err := merge4(block4s) + merged4, err := merge4(block4s) if err != nil { return nil, err } + merged6, err := merge6(block6s) + if err != nil { + return nil, err + } + + merged := append(merged4, merged6...) return merged, nil }