diff --git a/dhcpv4/async/client.go b/dhcpv4/async/client.go deleted file mode 100644 index 54b500d6..00000000 --- a/dhcpv4/async/client.go +++ /dev/null @@ -1,217 +0,0 @@ -package async - -import ( - "context" - "fmt" - "log" - "net" - "sync" - "time" - - promise "github.com/fanliao/go-promise" - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv4/client4" -) - -// Default ports -const ( - DefaultServerPort = 67 - DefaultClientPort = 68 -) - -// Client implements an asynchronous DHCPv4 client -// It doesn't use the broadcast socket! Which means it should be used only when -// the network is already established. -// https://github.com/insomniacslk/dhcp/issues/143 -type Client struct { - ReadTimeout time.Duration - WriteTimeout time.Duration - LocalAddr net.Addr - RemoteAddr net.Addr - IgnoreErrors bool - - connection *net.UDPConn - cancel context.CancelFunc - stopping *sync.WaitGroup - receiveQueue chan *dhcpv4.DHCPv4 - sendQueue chan *dhcpv4.DHCPv4 - packetsLock sync.Mutex - packets map[dhcpv4.TransactionID]*promise.Promise - errors chan error -} - -// NewClient creates an asynchronous client -func NewClient() *Client { - return &Client{ - ReadTimeout: client4.DefaultReadTimeout, - WriteTimeout: client4.DefaultWriteTimeout, - } -} - -// Open starts the client. The requests made with Send function call are first -// put to the buffered channel and dispatched in FIFO order. BufferSize -// indicates the number of packets that can be waiting to be send before -// blocking the caller exectution. -func (c *Client) Open(bufferSize int) error { - var ( - addr *net.UDPAddr - ok bool - err error - ) - - if addr, ok = c.LocalAddr.(*net.UDPAddr); !ok { - return fmt.Errorf("Invalid local address: %v not a net.UDPAddr", c.LocalAddr) - } - - // prepare the socket to listen on for replies - c.connection, err = net.ListenUDP("udp4", addr) - if err != nil { - return err - } - c.stopping = new(sync.WaitGroup) - c.sendQueue = make(chan *dhcpv4.DHCPv4, bufferSize) - c.receiveQueue = make(chan *dhcpv4.DHCPv4, bufferSize) - c.packets = make(map[dhcpv4.TransactionID]*promise.Promise) - c.packetsLock = sync.Mutex{} - c.errors = make(chan error) - - var ctx context.Context - ctx, c.cancel = context.WithCancel(context.Background()) - go c.receiverLoop(ctx) - go c.senderLoop(ctx) - - return nil -} - -// Close stops the client -func (c *Client) Close() { - // Wait for sender and receiver loops - c.stopping.Add(2) - c.cancel() - c.stopping.Wait() - - close(c.sendQueue) - close(c.receiveQueue) - close(c.errors) - - c.connection.Close() -} - -// Errors returns a channel where runtime errors are posted -func (c *Client) Errors() <-chan error { - return c.errors -} - -func (c *Client) addError(err error) { - if !c.IgnoreErrors { - c.errors <- err - } -} - -func (c *Client) receiverLoop(ctx context.Context) { - defer func() { c.stopping.Done() }() - for { - select { - case <-ctx.Done(): - return - case packet := <-c.receiveQueue: - c.receive(packet) - } - } -} - -func (c *Client) senderLoop(ctx context.Context) { - defer func() { c.stopping.Done() }() - for { - select { - case <-ctx.Done(): - return - case packet := <-c.sendQueue: - c.send(packet) - } - } -} - -func (c *Client) send(packet *dhcpv4.DHCPv4) { - c.packetsLock.Lock() - p := c.packets[packet.TransactionID] - c.packetsLock.Unlock() - - raddr, err := c.remoteAddr() - if err != nil { - _ = p.Reject(err) - return - } - - if err := c.connection.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil { - log.Printf("Warning: cannot set write deadline: %v", err) - return - } - _, err = c.connection.WriteTo(packet.ToBytes(), raddr) - if err != nil { - _ = p.Reject(err) - log.Printf("Warning: cannot write to %s: %v", raddr, err) - return - } - - c.receiveQueue <- packet -} - -func (c *Client) receive(_ *dhcpv4.DHCPv4) { - var ( - oobdata = []byte{} - received *dhcpv4.DHCPv4 - ) - - if err := c.connection.SetReadDeadline(time.Now().Add(c.ReadTimeout)); err != nil { - log.Printf("Warning: cannot set write deadline: %v", err) - return - } - for { - buffer := make([]byte, client4.MaxUDPReceivedPacketSize) - n, _, _, _, err := c.connection.ReadMsgUDP(buffer, oobdata) - if err != nil { - if err, ok := err.(net.Error); !ok || !err.Timeout() { - c.addError(fmt.Errorf("Error receiving the message: %s", err)) - } - return - } - received, err = dhcpv4.FromBytes(buffer[:n]) - if err == nil { - break - } - } - - c.packetsLock.Lock() - if p, ok := c.packets[received.TransactionID]; ok { - delete(c.packets, received.TransactionID) - _ = p.Resolve(received) - } - c.packetsLock.Unlock() -} - -func (c *Client) remoteAddr() (*net.UDPAddr, error) { - if c.RemoteAddr == nil { - return &net.UDPAddr{IP: net.IPv4bcast, Port: DefaultServerPort}, nil - } - - if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok { - return addr, nil - } - return nil, fmt.Errorf("Invalid remote address: %v not a net.UDPAddr", c.RemoteAddr) -} - -// Send inserts a message to the queue to be sent asynchronously. -// Returns a future which resolves to response and error. -func (c *Client) Send(message *dhcpv4.DHCPv4, modifiers ...dhcpv4.Modifier) *promise.Future { - for _, mod := range modifiers { - mod(message) - } - - p := promise.NewPromise() - c.packetsLock.Lock() - c.packets[message.TransactionID] = p - c.packetsLock.Unlock() - c.sendQueue <- message - return p.Future -} diff --git a/dhcpv4/async/client_test.go b/dhcpv4/async/client_test.go deleted file mode 100644 index 2269d57c..00000000 --- a/dhcpv4/async/client_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package async - -import ( - "context" - "net" - "testing" - "time" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv4/client4" - "github.com/stretchr/testify/require" -) - -// server creates a server which responds with a predefined response -func serve(ctx context.Context, addr *net.UDPAddr, response *dhcpv4.DHCPv4) error { - conn, err := net.ListenUDP("udp4", addr) - if err != nil { - return err - } - go func() { - defer conn.Close() - oobdata := []byte{} - buffer := make([]byte, client4.MaxUDPReceivedPacketSize) - for { - select { - case <-ctx.Done(): - return - default: - if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { - panic(err) - } - n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) - if err != nil { - continue - } - _, err = dhcpv4.FromBytes(buffer[:n]) - if err != nil { - continue - } - if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil { - panic(err) - } - _, err = conn.WriteTo(response.ToBytes(), src) - if err != nil { - continue - } - } - } - }() - return nil -} - -func TestNewClient(t *testing.T) { - c := NewClient() - require.NotNil(t, c) - require.Equal(t, c.ReadTimeout, client4.DefaultReadTimeout) - require.Equal(t, c.ReadTimeout, client4.DefaultWriteTimeout) -} - -func TestOpenInvalidAddrFailes(t *testing.T) { - c := NewClient() - err := c.Open(512) - require.Error(t, err) -} - -// This test uses port 15438 so please make sure its not used before running -func TestOpenClose(t *testing.T) { - c := NewClient() - addr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15438") - require.NoError(t, err) - c.LocalAddr = addr - err = c.Open(512) - require.NoError(t, err) - defer c.Close() -} - -// This test uses ports 15438 and 15439 so please make sure they are not used -// before running -func TestSendTimeout(t *testing.T) { - c := NewClient() - addr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15438") - require.NoError(t, err) - remote, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15439") - require.NoError(t, err) - c.ReadTimeout = 50 * time.Millisecond - c.WriteTimeout = 50 * time.Millisecond - c.LocalAddr = addr - c.RemoteAddr = remote - err = c.Open(512) - require.NoError(t, err) - defer c.Close() - m, err := dhcpv4.New() - require.NoError(t, err) - _, err, timeout := c.Send(m).GetOrTimeout(200) - require.NoError(t, err) - require.True(t, timeout) -} - -// This test uses ports 15438 and 15439 so please make sure they are not used -// before running -func TestSend(t *testing.T) { - m, err := dhcpv4.New() - require.NoError(t, err) - require.NotNil(t, m) - - c := NewClient() - addr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15438") - require.NoError(t, err) - remote, err := net.ResolveUDPAddr("udp4", "127.0.0.1:15439") - require.NoError(t, err) - c.LocalAddr = addr - c.RemoteAddr = remote - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - err = serve(ctx, remote, m) - require.NoError(t, err) - - err = c.Open(16) - require.NoError(t, err) - defer c.Close() - - f := c.Send(m) - response, err, timeout := f.GetOrTimeout(2000) - r, ok := response.(*dhcpv4.DHCPv4) - require.True(t, ok) - require.False(t, timeout) - require.NoError(t, err) - require.Equal(t, m.TransactionID, r.TransactionID) -} diff --git a/dhcpv4/bsdp/client.go b/dhcpv4/bsdp/client.go deleted file mode 100644 index 5d0e667e..00000000 --- a/dhcpv4/bsdp/client.go +++ /dev/null @@ -1,75 +0,0 @@ -package bsdp - -import ( - "errors" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv4/client4" -) - -// Client represents a BSDP client that can perform BSDP exchanges via the -// broadcast address. -type Client struct { - client4.Client -} - -// NewClient constructs a new client with default read and write timeouts from -// dhcpv4.Client. -func NewClient() *Client { - return &Client{Client: client4.Client{}} -} - -// Exchange runs a full BSDP exchange (Inform[list], Ack, Inform[select], -// Ack). Returns a list of DHCPv4 structures representing the exchange. -func (c *Client) Exchange(ifname string) ([]*Packet, error) { - conversation := make([]*Packet, 0) - - // Get our file descriptor for the broadcast socket. - sendFd, err := client4.MakeBroadcastSocket(ifname) - if err != nil { - return conversation, err - } - recvFd, err := client4.MakeListeningSocket(ifname) - if err != nil { - return conversation, err - } - - // INFORM[LIST] - informList, err := NewInformListForInterface(ifname, dhcpv4.ClientPort) - if err != nil { - return conversation, err - } - conversation = append(conversation, informList) - - // ACK[LIST] - ackForList, err := c.Client.SendReceive(sendFd, recvFd, informList.v4(), dhcpv4.MessageTypeAck) - if err != nil { - return conversation, err - } - - // Rewrite vendor-specific option for pretty printing. - conversation = append(conversation, PacketFor(ackForList)) - - // Parse boot images sent back by server - bootImages, err := ParseBootImageListFromAck(ackForList) - if err != nil { - return conversation, err - } - if len(bootImages) == 0 { - return conversation, errors.New("got no BootImages from server") - } - - // INFORM[SELECT] - informSelect, err := InformSelectForAck(PacketFor(ackForList), dhcpv4.ClientPort, bootImages[0]) - if err != nil { - return conversation, err - } - conversation = append(conversation, informSelect) - - // ACK[SELECT] - ackForSelect, err := c.Client.SendReceive(sendFd, recvFd, informSelect.v4(), dhcpv4.MessageTypeAck) - if err != nil { - return conversation, err - } - return append(conversation, PacketFor(ackForSelect)), nil -} diff --git a/dhcpv4/client4/client.go b/dhcpv4/client4/client.go deleted file mode 100644 index f2747147..00000000 --- a/dhcpv4/client4/client.go +++ /dev/null @@ -1,371 +0,0 @@ -// Package client4 is deprecated. Use "nclient4" instead. -package client4 - -import ( - "encoding/binary" - "errors" - "fmt" - "log" - "net" - "reflect" - "time" - - "github.com/insomniacslk/dhcp/dhcpv4" - "golang.org/x/net/ipv4" - "golang.org/x/sys/unix" -) - -// MaxUDPReceivedPacketSize is the (arbitrary) maximum UDP packet size supported -// by this library. Theoretically could be up to 65kb. -const ( - MaxUDPReceivedPacketSize = 8192 -) - -var ( - // DefaultReadTimeout is the time to wait after listening in which the - // exchange is considered failed. - DefaultReadTimeout = 3 * time.Second - - // DefaultWriteTimeout is the time to wait after sending in which the - // exchange is considered failed. - DefaultWriteTimeout = 3 * time.Second -) - -// Client is the object that actually performs the DHCP exchange. It currently -// only has read and write timeout values, plus (optional) local and remote -// addresses. -type Client struct { - ReadTimeout, WriteTimeout time.Duration - RemoteAddr net.Addr - LocalAddr net.Addr -} - -// NewClient generates a new client to perform a DHCP exchange with, setting the -// read and write timeout fields to defaults. -func NewClient() *Client { - return &Client{ - ReadTimeout: DefaultReadTimeout, - WriteTimeout: DefaultWriteTimeout, - } -} - -// MakeRawUDPPacket converts a payload (a serialized DHCPv4 packet) into a -// raw UDP packet for the specified serverAddr from the specified clientAddr. -func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) { - udp := make([]byte, 8) - binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port)) - binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port)) - binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(payload))) - binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum - - h := ipv4.Header{ - Version: 4, - Len: 20, - TotalLen: 20 + len(udp) + len(payload), - TTL: 64, - Protocol: 17, // UDP - Dst: serverAddr.IP, - Src: clientAddr.IP, - } - ret, err := h.Marshal() - if err != nil { - return nil, err - } - ret = append(ret, udp...) - ret = append(ret, payload...) - return ret, nil -} - -// makeRawSocket creates a socket that can be passed to unix.Sendto. -func makeRawSocket(ifname string) (int, error) { - fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW) - if err != nil { - return fd, err - } - err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) - if err != nil { - return fd, err - } - err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_HDRINCL, 1) - if err != nil { - return fd, err - } - err = dhcpv4.BindToInterface(fd, ifname) - if err != nil { - return fd, err - } - return fd, nil -} - -// MakeBroadcastSocket creates a socket that can be passed to unix.Sendto -// that will send packets out to the broadcast address. -func MakeBroadcastSocket(ifname string) (int, error) { - fd, err := makeRawSocket(ifname) - if err != nil { - return fd, err - } - err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1) - if err != nil { - return fd, err - } - return fd, nil -} - -// MakeListeningSocket creates a listening socket on 0.0.0.0 for the DHCP client -// port and returns it. -func MakeListeningSocket(ifname string) (int, error) { - return makeListeningSocketWithCustomPort(ifname, dhcpv4.ClientPort) -} - -func htons(v uint16) uint16 { - var tmp [2]byte - binary.BigEndian.PutUint16(tmp[:], v) - return binary.LittleEndian.Uint16(tmp[:]) -} - -func makeListeningSocketWithCustomPort(ifname string, port int) (int, error) { - fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM, int(htons(unix.ETH_P_IP))) - if err != nil { - return fd, err - } - iface, err := net.InterfaceByName(ifname) - if err != nil { - return fd, err - } - llAddr := unix.SockaddrLinklayer{ - Ifindex: iface.Index, - Protocol: htons(unix.ETH_P_IP), - } - err = unix.Bind(fd, &llAddr) - return fd, err -} - -func toUDPAddr(addr net.Addr, defaultAddr *net.UDPAddr) (*net.UDPAddr, error) { - var uaddr *net.UDPAddr - if addr == nil { - uaddr = defaultAddr - } else { - if addr, ok := addr.(*net.UDPAddr); ok { - uaddr = addr - } else { - return nil, fmt.Errorf("could not convert to net.UDPAddr, got %v instead", reflect.TypeOf(addr)) - } - } - if uaddr.IP.To4() == nil { - return nil, fmt.Errorf("'%s' is not a valid IPv4 address", uaddr.IP) - } - return uaddr, nil -} - -func (c *Client) getLocalUDPAddr() (*net.UDPAddr, error) { - defaultLocalAddr := &net.UDPAddr{IP: net.IPv4zero, Port: dhcpv4.ClientPort} - laddr, err := toUDPAddr(c.LocalAddr, defaultLocalAddr) - if err != nil { - return nil, fmt.Errorf("Invalid local address: %s", err) - } - return laddr, nil -} - -func (c *Client) getRemoteUDPAddr() (*net.UDPAddr, error) { - defaultRemoteAddr := &net.UDPAddr{IP: net.IPv4bcast, Port: dhcpv4.ServerPort} - raddr, err := toUDPAddr(c.RemoteAddr, defaultRemoteAddr) - if err != nil { - return nil, fmt.Errorf("Invalid remote address: %s", err) - } - return raddr, nil -} - -// Exchange runs a full DORA transaction: Discover, Offer, Request, Acknowledge, -// over UDP. Does not retry in case of failures. Returns a list of DHCPv4 -// structures representing the exchange. It can contain up to four elements, -// ordered as Discovery, Offer, Request and Acknowledge. In case of errors, an -// error is returned, and the list of DHCPv4 objects will be shorted than 4, -// containing all the sent and received DHCPv4 messages. -func (c *Client) Exchange(ifname string, modifiers ...dhcpv4.Modifier) ([]*dhcpv4.DHCPv4, error) { - conversation := make([]*dhcpv4.DHCPv4, 0) - raddr, err := c.getRemoteUDPAddr() - if err != nil { - return nil, err - } - laddr, err := c.getLocalUDPAddr() - if err != nil { - return nil, err - } - // Get our file descriptor for the raw socket we need. - var sfd int - // If the address is not net.IPV4bcast, use a unicast socket. This should - // cover the majority of use cases, but we're essentially ignoring the fact - // that the IP could be the broadcast address of a specific subnet. - if raddr.IP.Equal(net.IPv4bcast) { - sfd, err = MakeBroadcastSocket(ifname) - } else { - sfd, err = makeRawSocket(ifname) - } - if err != nil { - return conversation, err - } - rfd, err := makeListeningSocketWithCustomPort(ifname, laddr.Port) - if err != nil { - return conversation, err - } - - defer func() { - // close the sockets - if err := unix.Close(sfd); err != nil { - log.Printf("unix.Close(sendFd) failed: %v", err) - } - if sfd != rfd { - if err := unix.Close(rfd); err != nil { - log.Printf("unix.Close(recvFd) failed: %v", err) - } - } - }() - - // Discover - discover, err := dhcpv4.NewDiscoveryForInterface(ifname, modifiers...) - if err != nil { - return conversation, err - } - conversation = append(conversation, discover) - - // Offer - offer, err := c.SendReceive(sfd, rfd, discover, dhcpv4.MessageTypeOffer) - if err != nil { - return conversation, err - } - conversation = append(conversation, offer) - - // Request - request, err := dhcpv4.NewRequestFromOffer(offer, modifiers...) - if err != nil { - return conversation, err - } - conversation = append(conversation, request) - - // Ack - ack, err := c.SendReceive(sfd, rfd, request, dhcpv4.MessageTypeAck) - if err != nil { - return conversation, err - } - conversation = append(conversation, ack) - - return conversation, nil -} - -// SendReceive sends a packet (with some write timeout) and waits for a -// response up to some read timeout value. If the message type is not -// MessageTypeNone, it will wait for a specific message type -func (c *Client) SendReceive(sendFd, recvFd int, packet *dhcpv4.DHCPv4, messageType dhcpv4.MessageType) (*dhcpv4.DHCPv4, error) { - raddr, err := c.getRemoteUDPAddr() - if err != nil { - return nil, err - } - laddr, err := c.getLocalUDPAddr() - if err != nil { - return nil, err - } - packetBytes, err := MakeRawUDPPacket(packet.ToBytes(), *raddr, *laddr) - if err != nil { - return nil, err - } - - // Create a goroutine to perform the blocking send, and time it out after - // a certain amount of time. - var ( - destination [net.IPv4len]byte - response *dhcpv4.DHCPv4 - ) - copy(destination[:], raddr.IP.To4()) - remoteAddr := unix.SockaddrInet4{Port: laddr.Port, Addr: destination} - recvErrors := make(chan error, 1) - go func(errs chan<- error) { - // set read timeout - timeout := unix.NsecToTimeval(c.ReadTimeout.Nanoseconds()) - if innerErr := unix.SetsockoptTimeval(recvFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout); innerErr != nil { - errs <- innerErr - return - } - for { - buf := make([]byte, MaxUDPReceivedPacketSize) - n, _, innerErr := unix.Recvfrom(recvFd, buf, 0) - if innerErr != nil { - errs <- innerErr - return - } - - var iph ipv4.Header - if err := iph.Parse(buf[:n]); err != nil { - // skip non-IP data - continue - } - if iph.Protocol != 17 { - // skip non-UDP packets - continue - } - udph := buf[iph.Len:n] - // check source and destination ports - srcPort := int(binary.BigEndian.Uint16(udph[0:2])) - expectedSrcPort := dhcpv4.ServerPort - if c.RemoteAddr != nil { - expectedSrcPort = c.RemoteAddr.(*net.UDPAddr).Port - } - if srcPort != expectedSrcPort { - continue - } - dstPort := int(binary.BigEndian.Uint16(udph[2:4])) - expectedDstPort := dhcpv4.ClientPort - if c.LocalAddr != nil { - expectedDstPort = c.LocalAddr.(*net.UDPAddr).Port - } - if dstPort != expectedDstPort { - continue - } - // UDP checksum is not checked - pLen := int(binary.BigEndian.Uint16(udph[4:6])) - payload := buf[iph.Len+8 : iph.Len+pLen] - - response, innerErr = dhcpv4.FromBytes(payload) - if innerErr != nil { - errs <- innerErr - return - } - // check that this is a response to our message - if response.TransactionID != packet.TransactionID { - continue - } - // wait for a response message - if response.OpCode != dhcpv4.OpcodeBootReply { - continue - } - // if we are not requested to wait for a specific message type, - // return what we have - if messageType == dhcpv4.MessageTypeNone { - break - } - // break if it's a reply of the desired type, continue otherwise - if response.MessageType() == messageType { - break - } - } - recvErrors <- nil - }(recvErrors) - - // send the request while the goroutine waits for replies - if err = unix.Sendto(sendFd, packetBytes, 0, &remoteAddr); err != nil { - return nil, err - } - - select { - case err = <-recvErrors: - if err == unix.EAGAIN { - return nil, errors.New("timed out while listening for replies") - } - if err != nil { - return nil, err - } - case <-time.After(c.ReadTimeout): - return nil, errors.New("timed out while listening for replies") - } - - return response, nil -} diff --git a/dhcpv6/async/client.go b/dhcpv6/async/client.go deleted file mode 100644 index 8a7fbff6..00000000 --- a/dhcpv6/async/client.go +++ /dev/null @@ -1,236 +0,0 @@ -package async - -import ( - "context" - "fmt" - "log" - "net" - "sync" - "time" - - promise "github.com/fanliao/go-promise" - "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/insomniacslk/dhcp/dhcpv6/client6" -) - -// Client implements an asynchronous DHCPv6 client -type Client struct { - ReadTimeout time.Duration - WriteTimeout time.Duration - LocalAddr net.Addr - RemoteAddr net.Addr - IgnoreErrors bool - - connection *net.UDPConn - cancel context.CancelFunc - stopping *sync.WaitGroup - receiveQueue chan dhcpv6.DHCPv6 - sendQueue chan dhcpv6.DHCPv6 - packetsLock sync.Mutex - packets map[dhcpv6.TransactionID]*promise.Promise - errors chan error -} - -// NewClient creates an asynchronous client -func NewClient() *Client { - return &Client{ - ReadTimeout: client6.DefaultReadTimeout, - WriteTimeout: client6.DefaultWriteTimeout, - } -} - -// OpenForInterface starts the client on the specified interface, replacing -// client LocalAddr with a link-local address of the given interface and -// standard DHCP port (546). -func (c *Client) OpenForInterface(ifname string, bufferSize int) error { - addr, err := dhcpv6.GetLinkLocalAddr(ifname) - if err != nil { - return err - } - c.LocalAddr = &net.UDPAddr{IP: addr, Port: dhcpv6.DefaultClientPort, Zone: ifname} - return c.Open(bufferSize) -} - -// Open starts the client -func (c *Client) Open(bufferSize int) error { - var ( - addr *net.UDPAddr - ok bool - err error - ) - - if addr, ok = c.LocalAddr.(*net.UDPAddr); !ok { - return fmt.Errorf("Invalid local address: %v not a net.UDPAddr", c.LocalAddr) - } - - // prepare the socket to listen on for replies - c.connection, err = net.ListenUDP("udp6", addr) - if err != nil { - return err - } - c.stopping = new(sync.WaitGroup) - c.sendQueue = make(chan dhcpv6.DHCPv6, bufferSize) - c.receiveQueue = make(chan dhcpv6.DHCPv6, bufferSize) - c.packets = make(map[dhcpv6.TransactionID]*promise.Promise) - c.packetsLock = sync.Mutex{} - c.errors = make(chan error) - - var ctx context.Context - ctx, c.cancel = context.WithCancel(context.Background()) - go c.receiverLoop(ctx) - go c.senderLoop(ctx) - - return nil -} - -// Close stops the client -func (c *Client) Close() { - // Wait for sender and receiver loops - c.stopping.Add(2) - c.cancel() - c.stopping.Wait() - - close(c.sendQueue) - close(c.receiveQueue) - close(c.errors) - - c.connection.Close() -} - -// Errors returns a channel where runtime errors are posted -func (c *Client) Errors() <-chan error { - return c.errors -} - -func (c *Client) addError(err error) { - if !c.IgnoreErrors { - c.errors <- err - } -} - -func (c *Client) receiverLoop(ctx context.Context) { - defer func() { c.stopping.Done() }() - for { - select { - case <-ctx.Done(): - return - case packet := <-c.receiveQueue: - c.receive(packet) - } - } -} - -func (c *Client) senderLoop(ctx context.Context) { - defer func() { c.stopping.Done() }() - for { - select { - case <-ctx.Done(): - return - case packet := <-c.sendQueue: - c.send(packet) - } - } -} - -func (c *Client) send(packet dhcpv6.DHCPv6) { - transactionID, err := dhcpv6.GetTransactionID(packet) - if err != nil { - c.addError(fmt.Errorf("Warning: This should never happen, there is no transaction ID on %s", packet)) - return - } - c.packetsLock.Lock() - p := c.packets[transactionID] - c.packetsLock.Unlock() - - raddr, err := c.remoteAddr() - if err != nil { - _ = p.Reject(err) - log.Printf("Warning: cannot get remote address :%v", err) - return - } - - if err := c.connection.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil { - _ = p.Reject(err) - log.Printf("Warning: cannot set write deadline :%v", err) - return - } - _, err = c.connection.WriteTo(packet.ToBytes(), raddr) - if err != nil { - _ = p.Reject(err) - log.Printf("Warning: cannot write to %s :%v", raddr, err) - return - } - - c.receiveQueue <- packet -} - -func (c *Client) receive(_ dhcpv6.DHCPv6) { - var ( - oobdata = []byte{} - received dhcpv6.DHCPv6 - ) - - if err := c.connection.SetReadDeadline(time.Now().Add(c.ReadTimeout)); err != nil { - log.Printf("Warning: cannot set read deadline :%v", err) - } - for { - buffer := make([]byte, client6.MaxUDPReceivedPacketSize) - n, _, _, _, err := c.connection.ReadMsgUDP(buffer, oobdata) - if err != nil { - if err, ok := err.(net.Error); !ok || !err.Timeout() { - c.addError(fmt.Errorf("Error receiving the message: %s", err)) - } - return - } - received, err = dhcpv6.FromBytes(buffer[:n]) - if err != nil { - // skip non-DHCP packets - continue - } - break - } - - transactionID, err := dhcpv6.GetTransactionID(received) - if err != nil { - c.addError(fmt.Errorf("Unable to get a transactionID for %s: %s", received, err)) - return - } - - c.packetsLock.Lock() - if p, ok := c.packets[transactionID]; ok { - delete(c.packets, transactionID) - _ = p.Resolve(received) - } - c.packetsLock.Unlock() -} - -func (c *Client) remoteAddr() (*net.UDPAddr, error) { - if c.RemoteAddr == nil { - return &net.UDPAddr{IP: dhcpv6.AllDHCPRelayAgentsAndServers, Port: dhcpv6.DefaultServerPort}, nil - } - - if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok { - return addr, nil - } - return nil, fmt.Errorf("Invalid remote address: %v not a net.UDPAddr", c.RemoteAddr) -} - -// Send inserts a message to the queue to be sent asynchronously. -// Returns a future which resolves to response and error. -func (c *Client) Send(message dhcpv6.DHCPv6, modifiers ...dhcpv6.Modifier) *promise.Future { - for _, mod := range modifiers { - mod(message) - } - - transactionID, err := dhcpv6.GetTransactionID(message) - if err != nil { - return promise.Wrap(err) - } - - p := promise.NewPromise() - c.packetsLock.Lock() - c.packets[transactionID] = p - c.packetsLock.Unlock() - c.sendQueue <- message - return p.Future -} diff --git a/dhcpv6/async/client_test.go b/dhcpv6/async/client_test.go deleted file mode 100644 index 14b8026e..00000000 --- a/dhcpv6/async/client_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package async - -import ( - "context" - "net" - "testing" - "time" - - "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/insomniacslk/dhcp/dhcpv6/client6" - "github.com/stretchr/testify/require" -) - -const retries = 5 - -// solicit creates new solicit based on the mac address -func solicit(input string) (*dhcpv6.Message, error) { - mac, err := net.ParseMAC(input) - if err != nil { - return nil, err - } - return dhcpv6.NewSolicit(mac) -} - -// server creates a server which responds with a predefined response -func serve(ctx context.Context, addr *net.UDPAddr, response dhcpv6.DHCPv6) error { - conn, err := net.ListenUDP("udp6", addr) - if err != nil { - return err - } - go func() { - defer conn.Close() - oobdata := []byte{} - buffer := make([]byte, client6.MaxUDPReceivedPacketSize) - for { - select { - case <-ctx.Done(): - return - default: - if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { - panic(err) - } - n, _, _, src, err := conn.ReadMsgUDP(buffer, oobdata) - if err != nil { - continue - } - _, err = dhcpv6.FromBytes(buffer[:n]) - if err != nil { - continue - } - if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil { - panic(err) - } - _, err = conn.WriteTo(response.ToBytes(), src) - if err != nil { - continue - } - } - } - }() - return nil -} - -func TestNewClient(t *testing.T) { - c := NewClient() - require.NotNil(t, c) - require.Equal(t, c.ReadTimeout, client6.DefaultReadTimeout) - require.Equal(t, c.ReadTimeout, client6.DefaultWriteTimeout) -} - -func TestOpenInvalidAddrFailes(t *testing.T) { - c := NewClient() - err := c.Open(512) - require.Error(t, err) -} - -// This test uses port 15438 so please make sure its not used before running -func TestOpenClose(t *testing.T) { - c := NewClient() - addr, err := net.ResolveUDPAddr("udp6", ":15438") - require.NoError(t, err) - c.LocalAddr = addr - err = c.Open(512) - require.NoError(t, err) - defer c.Close() -} - -// This test uses ports 15438 and 15439 so please make sure they are not used -// before running -func TestSendTimeout(t *testing.T) { - c := NewClient() - addr, err := net.ResolveUDPAddr("udp6", ":15438") - require.NoError(t, err) - remote, err := net.ResolveUDPAddr("udp6", ":15439") - require.NoError(t, err) - c.ReadTimeout = 50 * time.Millisecond - c.WriteTimeout = 50 * time.Millisecond - c.LocalAddr = addr - c.RemoteAddr = remote - err = c.Open(512) - require.NoError(t, err) - defer c.Close() - m, err := dhcpv6.NewMessage() - require.NoError(t, err) - _, err, timeout := c.Send(m).GetOrTimeout(200) - require.NoError(t, err) - require.True(t, timeout) -} - -// This test uses ports 15438 and 15439 so please make sure they are not used -// before running -func TestSend(t *testing.T) { - s, err := solicit("c8:6c:2c:47:96:fd") - require.NoError(t, err) - require.NotNil(t, s) - - a, err := dhcpv6.NewAdvertiseFromSolicit(s) - require.NoError(t, err) - require.NotNil(t, a) - - c := NewClient() - addr, err := net.ResolveUDPAddr("udp6", ":15438") - require.NoError(t, err) - remote, err := net.ResolveUDPAddr("udp6", ":15439") - require.NoError(t, err) - c.LocalAddr = addr - c.RemoteAddr = remote - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - err = serve(ctx, remote, a) - require.NoError(t, err) - - err = c.Open(16) - require.NoError(t, err) - defer c.Close() - - f := c.Send(s) - - var passed bool - for i := 0; i < retries; i++ { - response, err, timeout := f.GetOrTimeout(1000) - if timeout { - continue - } - passed = true - require.NoError(t, err) - require.Equal(t, a, response) - } - require.True(t, passed, "All attempts to TestSend timed out") -} diff --git a/dhcpv6/client6/client.go b/dhcpv6/client6/client.go deleted file mode 100644 index e10b61f1..00000000 --- a/dhcpv6/client6/client.go +++ /dev/null @@ -1,237 +0,0 @@ -package client6 - -import ( - "errors" - "fmt" - "net" - "time" - - "github.com/insomniacslk/dhcp/dhcpv6" -) - -// Client constants -const ( - DefaultWriteTimeout = 3 * time.Second // time to wait for write calls - DefaultReadTimeout = 3 * time.Second // time to wait for read calls - DefaultInterfaceUpTimeout = 3 * time.Second // time to wait before a network interface goes up - MaxUDPReceivedPacketSize = 8192 // arbitrary size. Theoretically could be up to 65kb -) - -// Client implements a DHCPv6 client -type Client struct { - ReadTimeout time.Duration - WriteTimeout time.Duration - LocalAddr net.Addr - RemoteAddr net.Addr - SimulateRelay bool - RelayOptions dhcpv6.Options // These options will be added to relay message if SimulateRelay is true -} - -// NewClient returns a Client with default settings -func NewClient() *Client { - return &Client{ - ReadTimeout: DefaultReadTimeout, - WriteTimeout: DefaultWriteTimeout, - } -} - -// Exchange executes a 4-way DHCPv6 request (Solicit, Advertise, Request, -// Reply). The modifiers will be applied to the Solicit and Request packets. -// A common use is to make sure that the Solicit packet has the right options, -// see modifiers.go -func (c *Client) Exchange(ifname string, modifiers ...dhcpv6.Modifier) ([]dhcpv6.DHCPv6, error) { - conversation := make([]dhcpv6.DHCPv6, 0) - var err error - - // Solicit - solicit, advertise, err := c.Solicit(ifname, modifiers...) - if solicit != nil { - conversation = append(conversation, solicit) - } - if err != nil { - return conversation, err - } - conversation = append(conversation, advertise) - - // Decapsulate advertise if it's relayed before passing it to Request - if advertise.IsRelay() { - advertiseRelay := advertise.(*dhcpv6.RelayMessage) - advertise, err = advertiseRelay.GetInnerMessage() - if err != nil { - return conversation, err - } - } - request, reply, err := c.Request(ifname, advertise.(*dhcpv6.Message), modifiers...) - if request != nil { - conversation = append(conversation, request) - } - if err != nil { - return conversation, err - } - conversation = append(conversation, reply) - return conversation, nil -} - -func (c *Client) sendReceive(ifname string, packet dhcpv6.DHCPv6, expectedType dhcpv6.MessageType) (dhcpv6.DHCPv6, error) { - if packet == nil { - return nil, fmt.Errorf("Packet to send cannot be nil") - } - // if no LocalAddr is specified, get the interface's link-local address - var laddr net.UDPAddr - if c.LocalAddr == nil { - llAddr, err := dhcpv6.GetLinkLocalAddr(ifname) - if err != nil { - return nil, err - } - laddr = net.UDPAddr{IP: llAddr, Port: dhcpv6.DefaultClientPort, Zone: ifname} - } else { - if addr, ok := c.LocalAddr.(*net.UDPAddr); ok { - laddr = *addr - } else { - return nil, fmt.Errorf("Invalid local address: not a net.UDPAddr: %v", c.LocalAddr) - } - } - if c.SimulateRelay { - var err error - packet, err = dhcpv6.EncapsulateRelay(packet, dhcpv6.MessageTypeRelayForward, net.IPv6zero, laddr.IP) - if err != nil { - return nil, err - } - // Add Relay Options to ecapsulated Packet - for _, opt := range c.RelayOptions { - packet.UpdateOption(opt) - } - } - if expectedType == dhcpv6.MessageTypeNone { - // infer the expected type from the packet being sent - if packet.Type() == dhcpv6.MessageTypeSolicit { - expectedType = dhcpv6.MessageTypeAdvertise - } else if packet.Type() == dhcpv6.MessageTypeRequest { - expectedType = dhcpv6.MessageTypeReply - } else if packet.Type() == dhcpv6.MessageTypeRelayForward { - expectedType = dhcpv6.MessageTypeRelayReply - } else if packet.Type() == dhcpv6.MessageTypeLeaseQuery { - expectedType = dhcpv6.MessageTypeLeaseQueryReply - } // and probably more - } - - // if no RemoteAddr is specified, use AllDHCPRelayAgentsAndServers - var raddr net.UDPAddr - if c.RemoteAddr == nil { - raddr = net.UDPAddr{IP: dhcpv6.AllDHCPRelayAgentsAndServers, Port: dhcpv6.DefaultServerPort} - } else { - if addr, ok := c.RemoteAddr.(*net.UDPAddr); ok { - raddr = *addr - } else { - return nil, fmt.Errorf("Invalid remote address: not a net.UDPAddr: %v", c.RemoteAddr) - } - } - - // prepare the socket to listen on for replies - conn, err := net.ListenUDP("udp6", &laddr) - if err != nil { - return nil, err - } - defer conn.Close() - // wait for the listener to be ready, fail if it takes too much time - deadline := time.Now().Add(time.Second) - for { - if now := time.Now(); now.After(deadline) { - return nil, errors.New("Timed out waiting for listener to be ready") - } - if conn.LocalAddr() != nil { - break - } - time.Sleep(10 * time.Millisecond) - } - - // send the packet out - if err := conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil { - return nil, err - } - _, err = conn.WriteTo(packet.ToBytes(), &raddr) - if err != nil { - return nil, err - } - - // wait for a reply - oobdata := []byte{} // ignoring oob data - if err := conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)); err != nil { - return nil, err - } - var ( - adv dhcpv6.DHCPv6 - isMessage bool - ) - defer conn.Close() - msg, ok := packet.(*dhcpv6.Message) - if ok { - isMessage = true - } - for { - buf := make([]byte, MaxUDPReceivedPacketSize) - n, _, _, _, err := conn.ReadMsgUDP(buf, oobdata) - if err != nil { - return nil, err - } - adv, err = dhcpv6.FromBytes(buf[:n]) - if err != nil { - // skip non-DHCP packets - // - // TODO: It also skips DHCP packets with any errors (for example - // if bootfile params are encoded incorrectly). We need to - // log such cases instead of silently skip them. - continue - } - if recvMsg, ok := adv.(*dhcpv6.Message); ok && isMessage { - // if a regular message, check the transaction ID first - // XXX should this unpack relay messages and check the XID of the - // inner packet too? - if msg.TransactionID != recvMsg.TransactionID { - // different XID, we don't want this packet for sure - continue - } - } - if expectedType == dhcpv6.MessageTypeNone { - // just take whatever arrived - break - } else if adv.Type() == expectedType { - break - } - } - return adv, nil -} - -// Solicit sends a Solicit, returns the Solicit, an Advertise (if not nil), and -// an error if any. The modifiers will be applied to the Solicit before sending -// it, see modifiers.go -func (c *Client) Solicit(ifname string, modifiers ...dhcpv6.Modifier) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) { - iface, err := net.InterfaceByName(ifname) - if err != nil { - return nil, nil, err - } - solicit, err := dhcpv6.NewSolicit(iface.HardwareAddr) - if err != nil { - return nil, nil, err - } - for _, mod := range modifiers { - mod(solicit) - } - advertise, err := c.sendReceive(ifname, solicit, dhcpv6.MessageTypeNone) - return solicit, advertise, err -} - -// Request sends a Request built from an Advertise. It returns the Request, a -// Reply (if not nil), and an error if any. The modifiers will be applied to -// the Request before sending it, see modifiers.go -func (c *Client) Request(ifname string, advertise *dhcpv6.Message, modifiers ...dhcpv6.Modifier) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) { - request, err := dhcpv6.NewRequestFromAdvertise(advertise) - if err != nil { - return nil, nil, err - } - for _, mod := range modifiers { - mod(request) - } - reply, err := c.sendReceive(ifname, request, dhcpv6.MessageTypeNone) - return request, reply, err -} diff --git a/dhcpv6/client6/client_test.go b/dhcpv6/client6/client_test.go deleted file mode 100644 index 1e05a629..00000000 --- a/dhcpv6/client6/client_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package client6 - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestNewClient(t *testing.T) { - c := NewClient() - require.NotNil(t, c) - require.Equal(t, DefaultReadTimeout, c.ReadTimeout) - require.Equal(t, DefaultWriteTimeout, c.WriteTimeout) -} diff --git a/examples/client6/README.md b/examples/client6/README.md deleted file mode 100644 index 5fa7bfad..00000000 --- a/examples/client6/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# DHCPv6 client - -A minimal DHCPv6 client can be implemented in a few lines of code, by using the default -client parameters. The example in [main.go](./main.go) lets you specify the -interface to send packets through, and defaults to `eth0`. diff --git a/examples/client6/main.go b/examples/client6/main.go deleted file mode 100644 index 41cf6927..00000000 --- a/examples/client6/main.go +++ /dev/null @@ -1,42 +0,0 @@ -package main - -import ( - "flag" - "log" - - "github.com/insomniacslk/dhcp/dhcpv6/client6" -) - -var ( - iface = flag.String("i", "eth0", "Interface to configure via DHCPv6") -) - -func main() { - flag.Parse() - log.Printf("Starting DHCPv6 client on interface %s", *iface) - - // NewClient sets up a new DHCPv6 client with default values - // for read and write timeouts, for destination address and listening - // address - client := client6.NewClient() - - // Exchange runs a Solicit-Advertise-Request-Reply transaction on the - // specified network interface, and returns a list of DHCPv6 packets - // (a "conversation") and an error if any. Notice that Exchange may - // return a non-empty packet list even if there is an error. This is - // intended, because the transaction may fail at any point, and we - // still want to know what packets were exchanged until then. - // A default Solicit packet will be used during the "conversation", - // which can be manipulated by using modifiers. - conversation, err := client.Exchange(*iface) - - // Summary() prints a verbose representation of the exchanged packets. - for _, packet := range conversation { - log.Print(packet.Summary()) - } - // error handling is done *after* printing, so we still print the - // exchanged packets if any, as explained above. - if err != nil { - log.Fatal(err) - } -} diff --git a/netboot/netboot.go b/netboot/netboot.go deleted file mode 100644 index b32f69e4..00000000 --- a/netboot/netboot.go +++ /dev/null @@ -1,162 +0,0 @@ -package netboot - -import ( - "errors" - "fmt" - "log" - "time" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv4/client4" - "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/insomniacslk/dhcp/dhcpv6/client6" -) - -var sleeper = func(d time.Duration) { - time.Sleep(d) -} - -// BootConf is a structure describes everything a host needs to know to boot over network -type BootConf struct { - // NetConf is the network configuration of the client - NetConf - - // BootfileURL is "where is the image (kernel)". - // See RFC5970 section 3.1 for IPv6 and RFC2132 section 9.5 ("Bootfile name") for IPv4 - BootfileURL string - - // BootfileParam is "what arguments should we pass (cmdline)". - // See RFC5970 section 3.2 for IPv6. - BootfileParam []string -} - -// RequestNetbootv6 sends a netboot request via DHCPv6 and returns the exchanged packets. Additional modifiers -// can be passed to manipulate both solicit and advertise packets. -func RequestNetbootv6(ifname string, timeout time.Duration, retries int, modifiers ...dhcpv6.Modifier) ([]dhcpv6.DHCPv6, error) { - var ( - conversation []dhcpv6.DHCPv6 - err error - ) - modifiers = append(modifiers, dhcpv6.WithNetboot) - delay := 2 * time.Second - for i := 0; i <= retries; i++ { - log.Printf("sending request, attempt #%d", i+1) - - client := client6.NewClient() - client.ReadTimeout = timeout - conversation, err = client.Exchange(ifname, modifiers...) - if err != nil { - log.Printf("Client.Exchange failed: %v", err) - if i >= retries { - // don't wait at the end of the last attempt - return nil, fmt.Errorf("netboot failed after %d attempts: %v", retries+1, err) - } - log.Printf("sleeping %v before retrying", delay) - sleeper(delay) - // TODO add random splay - delay = delay * 2 - continue - } - break - } - return conversation, nil -} - -// RequestNetbootv4 sends a netboot request via DHCPv4 and returns the exchanged packets. Additional modifiers -// can be passed to manipulate both the discover and offer packets. -func RequestNetbootv4(ifname string, timeout time.Duration, retries int, modifiers ...dhcpv4.Modifier) ([]*dhcpv4.DHCPv4, error) { - var ( - conversation []*dhcpv4.DHCPv4 - err error - ) - delay := 2 * time.Second - modifiers = append(modifiers, dhcpv4.WithNetboot) - for i := 0; i <= retries; i++ { - log.Printf("sending request, attempt #%d", i+1) - client := client4.NewClient() - client.ReadTimeout = timeout - conversation, err = client.Exchange(ifname, modifiers...) - if err != nil { - log.Printf("Client.Exchange failed: %v", err) - log.Printf("sleeping %v before retrying", delay) - if i >= retries { - // don't wait at the end of the last attempt - break - } - sleeper(delay) - // TODO add random splay - delay = delay * 2 - continue - } - break - } - return conversation, nil -} - -// ConversationToNetconf extracts network configuration and boot file URL from a -// DHCPv6 4-way conversation and returns them, or an error if any. -func ConversationToNetconf(conversation []dhcpv6.DHCPv6) (*BootConf, error) { - var advertise, reply *dhcpv6.Message - for _, m := range conversation { - switch m.Type() { - case dhcpv6.MessageTypeAdvertise: - advertise = m.(*dhcpv6.Message) - case dhcpv6.MessageTypeReply: - reply = m.(*dhcpv6.Message) - } - } - if reply == nil { - return nil, errors.New("no REPLY received") - } - - bootconf := &BootConf{} - netconf, err := GetNetConfFromPacketv6(reply) - if err != nil { - return nil, fmt.Errorf("cannot get netconf from packet: %v", err) - } - bootconf.NetConf = *netconf - - if u := reply.Options.BootFileURL(); len(u) > 0 { - bootconf.BootfileURL = u - bootconf.BootfileParam = reply.Options.BootFileParam() - } else { - log.Printf("no bootfile URL option found in REPLY, fallback to ADVERTISE's value") - if u := advertise.Options.BootFileURL(); len(u) > 0 { - bootconf.BootfileURL = u - bootconf.BootfileParam = advertise.Options.BootFileParam() - } - } - if len(bootconf.BootfileURL) == 0 { - return nil, errors.New("no bootfile URL option found") - } - return bootconf, nil -} - -// ConversationToNetconfv4 extracts network configuration and boot file URL from a -// DHCPv4 4-way conversation and returns them, or an error if any. -func ConversationToNetconfv4(conversation []*dhcpv4.DHCPv4) (*BootConf, error) { - var reply *dhcpv4.DHCPv4 - for _, m := range conversation { - // look for a BootReply packet of type Offer containing the bootfile URL. - // Normally both packets with Message Type OFFER or ACK do contain - // the bootfile URL. - if m.OpCode == dhcpv4.OpcodeBootReply && m.MessageType() == dhcpv4.MessageTypeOffer { - reply = m - break - } - } - if reply == nil { - return nil, errors.New("no OFFER with valid bootfile URL received") - } - - bootconf := &BootConf{} - netconf, err := GetNetConfFromPacketv4(reply) - if err != nil { - return nil, fmt.Errorf("could not get netconf: %v", err) - } - bootconf.NetConf = *netconf - - bootconf.BootfileURL = reply.BootFileName - // TODO: should we support bootfile parameters here somehow? (see netconf.BootfileParam) - return bootconf, nil -} diff --git a/netboot/netconf.go b/netboot/netconf.go deleted file mode 100644 index 698c0d19..00000000 --- a/netboot/netconf.go +++ /dev/null @@ -1,253 +0,0 @@ -package netboot - -import ( - "errors" - "fmt" - "io/ioutil" - "net" - "os" - "strings" - "syscall" - "time" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/jsimonetti/rtnetlink" - "github.com/jsimonetti/rtnetlink/rtnl" - "github.com/mdlayher/netlink" -) - -// AddrConf holds a single IP address configuration for a NIC -type AddrConf struct { - IPNet net.IPNet - PreferredLifetime time.Duration - ValidLifetime time.Duration -} - -// NetConf holds multiple IP configuration for a NIC, and DNS configuration -type NetConf struct { - Addresses []AddrConf - DNSServers []net.IP - DNSSearchList []string - Routers []net.IP - NTPServers []net.IP -} - -// GetNetConfFromPacketv6 extracts network configuration information from a DHCPv6 -// Reply packet and returns a populated NetConf structure -func GetNetConfFromPacketv6(d *dhcpv6.Message) (*NetConf, error) { - iana := d.Options.OneIANA() - if iana == nil { - return nil, errors.New("no option IA NA found") - } - netconf := NetConf{} - - for _, iaaddr := range iana.Options.Addresses() { - netconf.Addresses = append(netconf.Addresses, AddrConf{ - IPNet: net.IPNet{ - IP: iaaddr.IPv6Addr, - - // This mask tells Linux which addresses we know to be - // "on-link" (i.e., reachable on this interface without - // having to talk to a router). - // - // Since DHCPv6 does not give us that information, we - // have to assume that no addresses are on-link. To do - // that, we use /128. (See also RFC 5942 Section 5, - // "Observed Incorrect Implementation Behavior".) - Mask: net.CIDRMask(128, 128), - }, - PreferredLifetime: iaaddr.PreferredLifetime, - ValidLifetime: iaaddr.ValidLifetime, - }) - } - // get DNS configuration - netconf.DNSServers = d.Options.DNS() - - // get domain search list - domains := d.Options.DomainSearchList() - if domains != nil { - netconf.DNSSearchList = domains.Labels - } - - // get NTP servers - netconf.NTPServers = d.Options.NTPServers() - - return &netconf, nil -} - -// GetNetConfFromPacketv4 extracts network configuration information from a DHCPv4 -// Reply packet and returns a populated NetConf structure -func GetNetConfFromPacketv4(d *dhcpv4.DHCPv4) (*NetConf, error) { - // extract the address from the DHCPv4 address - ipAddr := d.YourIPAddr - if ipAddr == nil || ipAddr.Equal(net.IPv4zero) { - return nil, errors.New("ip address is null (0.0.0.0)") - } - netconf := NetConf{} - - // get the subnet mask from OptionSubnetMask. If the netmask is not defined - // in the packet, an error is returned - netmask := d.SubnetMask() - if netmask == nil { - return nil, errors.New("no netmask option in response packet") - } - ones, _ := netmask.Size() - if ones == 0 { - return nil, errors.New("netmask extracted from OptSubnetMask options is null") - } - - // netconf struct requires a valid lifetime to be specified. ValidLifetime is a dhcpv6 - // concept, the closest mapping in dhcpv4 world is "IP Address Lease Time". If the lease - // time option is nil, we set it to 0 - leaseTime := d.IPAddressLeaseTime(0) - - netconf.Addresses = append(netconf.Addresses, AddrConf{ - IPNet: net.IPNet{ - IP: ipAddr, - Mask: netmask, - }, - PreferredLifetime: 0, - ValidLifetime: leaseTime, - }) - - // get DNS configuration - netconf.DNSServers = d.DNS() - - // get domain search list - dnsSearchList := d.DomainSearch() - if dnsSearchList != nil { - if len(dnsSearchList.Labels) == 0 { - return nil, errors.New("dns search list is empty") - } - netconf.DNSSearchList = dnsSearchList.Labels - } - - // get default gateway - routersList := d.Router() - if len(routersList) == 0 { - return nil, errors.New("no routers specified in the corresponding option") - } - netconf.Routers = routersList - - // get NTP servers - netconf.NTPServers = d.NTPServers() - - return &netconf, nil -} - -// IfUp brings up an interface by name, and waits for it to come up until a timeout expires -func IfUp(ifname string, timeout time.Duration) (_ *net.Interface, err error) { - start := time.Now() - rt, err := rtnl.Dial(nil) - if err != nil { - return nil, err - } - defer func() { - if cerr := rt.Close(); cerr != nil { - err = cerr - } - }() - for time.Since(start) < timeout { - iface, err := net.InterfaceByName(ifname) - if err != nil { - return nil, err - } - // If the interface is up, return. According to kernel documentation OperState may - // be either Up or Unknown: - // Interface is in RFC2863 operational state UP or UNKNOWN. This is for - // backward compatibility, routing daemons, dhcp clients can use this - // flag to determine whether they should use the interface. - // Source: https://www.kernel.org/doc/Documentation/networking/operstates.txt - operState, err := getOperState(iface.Index) - if err != nil { - return nil, err - } - if operState == rtnetlink.OperStateUp || operState == rtnetlink.OperStateUnknown { - // XXX despite the OperUp state, upon the first attempt I - // consistently get a "cannot assign requested address" error. Need - // to investigate more. - time.Sleep(time.Second) - return iface, nil - } - // otherwise try to bring it up - if err := rt.LinkUp(iface); err != nil { - return nil, fmt.Errorf("interface %q: %v can't bring it up: %v", ifname, iface, err) - } - time.Sleep(10 * time.Millisecond) - } - - return nil, fmt.Errorf("timed out while waiting for %s to come up", ifname) - -} - -// ConfigureInterface configures a network interface with the configuration held by a -// NetConf structure -func ConfigureInterface(ifname string, netconf *NetConf) (err error) { - iface, err := net.InterfaceByName(ifname) - if err != nil { - return err - } - rt, err := rtnl.Dial(nil) - if err != nil { - return err - } - defer func() { - if cerr := rt.Close(); err != nil { - err = cerr - } - }() - // configure interfaces - for _, addr := range netconf.Addresses { - if err := rt.AddrAdd(iface, &addr.IPNet); err != nil { - return fmt.Errorf("cannot configure %s on %s: %v", ifname, addr.IPNet, err) - } - } - // configure /etc/resolv.conf - resolvconf := "" - for _, ns := range netconf.DNSServers { - resolvconf += fmt.Sprintf("nameserver %s\n", ns) - } - if len(netconf.DNSSearchList) > 0 { - resolvconf += fmt.Sprintf("search %s\n", strings.Join(netconf.DNSSearchList, " ")) - } - if err = ioutil.WriteFile("/etc/resolv.conf", []byte(resolvconf), 0644); err != nil { - return fmt.Errorf("could not write resolv.conf file %v", err) - } - - // FIXME wut? No IPv6 here? - // add default route information for v4 space. only one default route is allowed - // so ignore the others if there are multiple ones - if len(netconf.Routers) > 0 { - // if there is a default v4 route, remove it, as we want to add the one we just got during - // the dhcp transaction. if the route is not present, which is the final state we want, - // an error is returned so ignore it - dst := net.IPNet{ - IP: net.IPv4zero, - Mask: net.CIDRMask(0, 32), - } - // Remove a possible default route (dst 0.0.0.0) to the L2 domain (gw: 0.0.0.0), which is what - // a client would want to add before initiating the DHCP transaction in order not to fail with - // ENETUNREACH. If this default route has a specific metric assigned, it doesn't get removed. - // The code doesn't remove any other default route (i.e. gw != 0.0.0.0). - if err := rt.RouteDel(iface, net.IPNet{IP: net.IPv4zero}); err != nil { - switch err := err.(type) { - case *netlink.OpError: - // ignore the error if it's -EEXIST or -ESRCH - if !os.IsExist(err.Err) && err.Err != syscall.ESRCH { - return fmt.Errorf("could not delete default route on interface %s: %v", ifname, err) - } - default: - return fmt.Errorf("could not delete default route on interface %s: %v", ifname, err) - } - } - - src := netconf.Addresses[0].IPNet - // TODO handle the remaining Routers if more than one - if err := rt.RouteAdd(iface, dst, netconf.Routers[0], rtnl.WithRouteSrc(&src)); err != nil { - return fmt.Errorf("could not add gateway %s for src %s dst %s to interface %s: %v", netconf.Routers[0], src, dst, ifname, err) - } - } - - return nil -} diff --git a/netboot/netconf_integ_test.go b/netboot/netconf_integ_test.go deleted file mode 100644 index 099296c1..00000000 --- a/netboot/netconf_integ_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// +build integration - -package netboot - -import ( - "fmt" - "io/ioutil" - "log" - "net" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// The test assumes that the interface exists and is configurable. -// If you are running this test locally, you may need to adjust this value. -var ifname = "eth0" - -func TestIfUp(t *testing.T) { - iface, err := IfUp(ifname, 2*time.Second) - require.NoError(t, err) - assert.Equal(t, ifname, iface.Name) -} - -func TestIfUpTimeout(t *testing.T) { - _, err := IfUp(ifname, 0*time.Second) - require.Error(t, err) -} - -func TestConfigureInterface(t *testing.T) { - // Linux-only. `netboot.ConfigureInterface` writes to /etc/resolv.conf when - // `NetConf.DNSServers` is set. In this test we make a backup of resolv.conf - // and subsequently restore it. This is really ugly, and not safe if - // multiple tests do the same. - resolvconf, err := ioutil.ReadFile("/etc/resolv.conf") - if err != nil { - panic(fmt.Sprintf("Failed to read /etc/resolv.conf: %v", err)) - } - type testCase struct { - Name string - NetConf *NetConf - } - testCases := []testCase{ - { - Name: "just IP addr", - NetConf: &NetConf{ - Addresses: []AddrConf{ - AddrConf{IPNet: net.IPNet{IP: net.ParseIP("10.20.30.40")}}, - }, - }, - }, - { - Name: "IP addr, DNS, and routers", - NetConf: &NetConf{ - Addresses: []AddrConf{ - AddrConf{IPNet: net.IPNet{IP: net.ParseIP("10.20.30.40")}}, - }, - DNSServers: []net.IP{net.ParseIP("8.8.8.8")}, - DNSSearchList: []string{"slackware.it"}, - Routers: []net.IP{net.ParseIP("10.20.30.254")}, - }, - }, - } - for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - require.NoError(t, ConfigureInterface(ifname, tc.NetConf)) - - // after the test, restore the content of /etc/resolv.conf . The permissions - // are used only if it didn't exist. - if err = ioutil.WriteFile("/etc/resolv.conf", resolvconf, 0644); err != nil { - panic(fmt.Sprintf("Failed to restore /etc/resolv.conf: %v", err)) - } - log.Printf("Restored /etc/resolv.conf") - }) - } -} diff --git a/netboot/netconf_test.go b/netboot/netconf_test.go deleted file mode 100644 index e2541cdd..00000000 --- a/netboot/netconf_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package netboot - -import ( - "log" - "net" - "testing" - "time" - - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/stretchr/testify/require" -) - -func getAdv(advModifiers ...dhcpv6.Modifier) *dhcpv6.Message { - hwaddr, err := net.ParseMAC("aa:bb:cc:dd:ee:ff") - if err != nil { - log.Panic(err) - } - - sol, err := dhcpv6.NewSolicit(hwaddr) - if err != nil { - log.Panic(err) - } - d, err := dhcpv6.NewAdvertiseFromSolicit(sol, advModifiers...) - if err != nil { - log.Panic(err) - } - return d -} - -func TestGetNetConfFromPacketv6Invalid(t *testing.T) { - adv := getAdv() - _, err := GetNetConfFromPacketv6(adv) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv6NoSearchList(t *testing.T) { - addrs := []dhcpv6.OptIAAddress{ - dhcpv6.OptIAAddress{ - IPv6Addr: net.ParseIP("::1"), - PreferredLifetime: 3600 * time.Second, - ValidLifetime: 5200 * time.Second, - }, - } - adv := getAdv( - dhcpv6.WithIANA(addrs...), - dhcpv6.WithDNS(net.ParseIP("fe80::1")), - ) - _, err := GetNetConfFromPacketv6(adv) - require.NoError(t, err) -} - -func TestGetNetConfFromPacketv6(t *testing.T) { - addrs := []dhcpv6.OptIAAddress{ - dhcpv6.OptIAAddress{ - IPv6Addr: net.ParseIP("::1"), - PreferredLifetime: 3600 * time.Second, - ValidLifetime: 5200 * time.Second, - }, - } - adv := getAdv( - dhcpv6.WithIANA(addrs...), - dhcpv6.WithDNS(net.ParseIP("fe80::1")), - dhcpv6.WithDomainSearchList("slackware.it"), - ) - netconf, err := GetNetConfFromPacketv6(adv) - require.NoError(t, err) - // check addresses - require.Equal(t, 1, len(netconf.Addresses)) - require.Equal(t, net.ParseIP("::1"), netconf.Addresses[0].IPNet.IP) - require.Equal(t, 3600*time.Second, netconf.Addresses[0].PreferredLifetime) - require.Equal(t, 5200*time.Second, netconf.Addresses[0].ValidLifetime) - // check DNSes - require.Equal(t, 1, len(netconf.DNSServers)) - require.Equal(t, net.ParseIP("fe80::1"), netconf.DNSServers[0]) - // check DNS search list - require.Equal(t, 1, len(netconf.DNSSearchList)) - require.Equal(t, "slackware.it", netconf.DNSSearchList[0]) - // check routers - require.Equal(t, 0, len(netconf.Routers)) -} - -func TestGetNetConfFromPacketv4AddrZero(t *testing.T) { - d, _ := dhcpv4.New(dhcpv4.WithYourIP(net.IPv4zero)) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4NoMask(t *testing.T) { - d, _ := dhcpv4.New(dhcpv4.WithYourIP(net.ParseIP("10.0.0.1"))) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4NullMask(t *testing.T) { - d, _ := dhcpv4.New( - dhcpv4.WithNetmask(net.IPv4Mask(0, 0, 0, 0)), - dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), - ) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4NoLeaseTime(t *testing.T) { - d, _ := dhcpv4.New( - dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), - dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), - ) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4EmptyDNSList(t *testing.T) { - d, _ := dhcpv4.New( - dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), - dhcpv4.WithLeaseTime(uint32(0)), - dhcpv4.WithDNS(), - dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), - ) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4NoSearchList(t *testing.T) { - d, _ := dhcpv4.New( - dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), - dhcpv4.WithLeaseTime(uint32(0)), - dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), - dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), - ) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4EmptySearchList(t *testing.T) { - d, _ := dhcpv4.New( - dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), - dhcpv4.WithLeaseTime(uint32(0)), - dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), - dhcpv4.WithDomainSearchList(), - dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), - ) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4NoRouter(t *testing.T) { - d, _ := dhcpv4.New( - dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), - dhcpv4.WithLeaseTime(uint32(0)), - dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), - dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"), - dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), - ) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4EmptyRouter(t *testing.T) { - d, _ := dhcpv4.New( - dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), - dhcpv4.WithLeaseTime(uint32(0)), - dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), - dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"), - dhcpv4.WithRouter(), - dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), - ) - _, err := GetNetConfFromPacketv4(d) - require.Error(t, err) -} - -func TestGetNetConfFromPacketv4(t *testing.T) { - d, _ := dhcpv4.New( - dhcpv4.WithNetmask(net.IPv4Mask(255, 255, 255, 0)), - dhcpv4.WithLeaseTime(uint32(5200)), - dhcpv4.WithDNS(net.ParseIP("10.10.0.1"), net.ParseIP("10.10.0.2")), - dhcpv4.WithDomainSearchList("slackware.it", "dhcp.slackware.it"), - dhcpv4.WithRouter(net.ParseIP("10.0.0.254")), - dhcpv4.WithYourIP(net.ParseIP("10.0.0.1")), - ) - - netconf, err := GetNetConfFromPacketv4(d) - require.NoError(t, err) - // check addresses - require.Equal(t, 1, len(netconf.Addresses)) - require.Equal(t, net.ParseIP("10.0.0.1"), netconf.Addresses[0].IPNet.IP) - require.Equal(t, time.Duration(0), netconf.Addresses[0].PreferredLifetime) - require.Equal(t, 5200*time.Second, netconf.Addresses[0].ValidLifetime) - // check DNSes - require.Equal(t, 2, len(netconf.DNSServers)) - require.Equal(t, net.ParseIP("10.10.0.1").To4(), netconf.DNSServers[0]) - require.Equal(t, net.ParseIP("10.10.0.2").To4(), netconf.DNSServers[1]) - // check DNS search list - require.Equal(t, 2, len(netconf.DNSSearchList)) - require.Equal(t, "slackware.it", netconf.DNSSearchList[0]) - require.Equal(t, "dhcp.slackware.it", netconf.DNSSearchList[1]) - // check routers - require.Equal(t, 1, len(netconf.Routers)) - require.Equal(t, net.ParseIP("10.0.0.254").To4(), netconf.Routers[0]) -} diff --git a/netboot/rtnetlink_linux.go b/netboot/rtnetlink_linux.go deleted file mode 100644 index f6886f8d..00000000 --- a/netboot/rtnetlink_linux.go +++ /dev/null @@ -1,27 +0,0 @@ -package netboot - -import ( - "log" - - "github.com/jsimonetti/rtnetlink" -) - -// getOperState returns the operational state for the given interface index. -func getOperState(iface int) (rtnetlink.OperationalState, error) { - conn, err := rtnetlink.Dial(nil) - if err != nil { - return 0, err - } - defer func() { - err := conn.Close() - if err != nil { - log.Printf("failed to close rtnetlink connection: %v", err) - } - }() - - msg, err := conn.Link.Get(uint32(iface)) - if err != nil { - return 0, err - } - return msg.Attributes.OperationalState, nil -}