diff --git a/httpx/access.go b/httpx/access.go index b80b702..ec62e12 100644 --- a/httpx/access.go +++ b/httpx/access.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/pkg/errors" "golang.org/x/net/context" ) @@ -52,3 +53,27 @@ func (c *AccessConfig) Allow(request *http.Request) (bool, error) { } return true, nil } + +// ParseNetworks parses a list of IPs and IP networks (written in CIDR notation) +func ParseNetworks(addrs ...string) ([]net.IP, []*net.IPNet, error) { + ips := make([]net.IP, 0, len(addrs)) + ipNets := make([]*net.IPNet, 0, len(addrs)) + + for _, addr := range addrs { + if strings.Contains(addr, "/") { + _, ipNet, err := net.ParseCIDR(addr) + if err != nil { + return nil, nil, errors.Errorf("couldn't parse '%s' as an IP network", addr) + } + ipNets = append(ipNets, ipNet) + } else { + ip := net.ParseIP(addr) + if ip == nil { + return nil, nil, errors.Errorf("couldn't parse '%s' as an IP address", addr) + } + ips = append(ips, ip) + } + } + + return ips, ipNets, nil +} diff --git a/httpx/access_test.go b/httpx/access_test.go index 5ef996b..73592af 100644 --- a/httpx/access_test.go +++ b/httpx/access_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/nyaruka/gocommon/httpx" - "github.com/stretchr/testify/assert" ) @@ -68,3 +67,32 @@ func TestAccessConfig(t *testing.T) { } } } + +func TestParseNetworkList(t *testing.T) { + privateNetwork1 := &net.IPNet{IP: net.IPv4(10, 0, 0, 0).To4(), Mask: net.CIDRMask(8, 32)} + privateNetwork2 := &net.IPNet{IP: net.IPv4(172, 16, 0, 0).To4(), Mask: net.CIDRMask(12, 32)} + privateNetwork3 := &net.IPNet{IP: net.IPv4(192, 168, 0, 0).To4(), Mask: net.CIDRMask(16, 32)} + + linkLocalIPv4 := &net.IPNet{IP: net.IPv4(169, 254, 0, 0).To4(), Mask: net.CIDRMask(16, 32)} + _, linkLocalIPv6, _ := net.ParseCIDR("fe80::/10") + + // test with mailroom defaults + ips, ipNets, err := httpx.ParseNetworks(`127.0.0.1`, `::1`, `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`, `169.254.0.0/16`, `fe80::/10`) + assert.NoError(t, err) + assert.Equal(t, []net.IP{net.IPv4(127, 0, 0, 1), net.ParseIP(`::1`)}, ips) + assert.Equal(t, []*net.IPNet{privateNetwork1, privateNetwork2, privateNetwork3, linkLocalIPv4, linkLocalIPv6}, ipNets) + + // test with empty + ips, ipNets, err = httpx.ParseNetworks() + assert.NoError(t, err) + assert.Equal(t, []net.IP{}, ips) + assert.Equal(t, []*net.IPNet{}, ipNets) + + // test with invalid IP + _, _, err = httpx.ParseNetworks(`127.0.1`) + assert.EqualError(t, err, `couldn't parse '127.0.1' as an IP address`) + + // test with invalid network + _, _, err = httpx.ParseNetworks(`127.0.0.1/x`) + assert.EqualError(t, err, `couldn't parse '127.0.0.1/x' as an IP network`) +}