diff options
Diffstat (limited to 'dhcpv4/client4/client.go')
-rw-r--r-- | dhcpv4/client4/client.go | 370 |
1 files changed, 370 insertions, 0 deletions
diff --git a/dhcpv4/client4/client.go b/dhcpv4/client4/client.go new file mode 100644 index 0000000..fca9c09 --- /dev/null +++ b/dhcpv4/client4/client.go @@ -0,0 +1,370 @@ +package client4 + +import ( + "encoding/binary" + "errors" + "fmt" + "log" + "net" + "reflect" + "time" + + "github.com/insomniacslk/dhcp/dhcpv4" + "golang.org/x/net/ipv4" + "golang.org/x/sys/unix" +) + +// MaxUDPReceivedPacketSize is the (arbitrary) maximum UDP packet size supported +// by this library. Theoretically could be up to 65kb. +const ( + MaxUDPReceivedPacketSize = 8192 +) + +var ( + // DefaultReadTimeout is the time to wait after listening in which the + // exchange is considered failed. + DefaultReadTimeout = 3 * time.Second + + // DefaultWriteTimeout is the time to wait after sending in which the + // exchange is considered failed. + DefaultWriteTimeout = 3 * time.Second +) + +// Client is the object that actually performs the DHCP exchange. It currently +// only has read and write timeout values, plus (optional) local and remote +// addresses. +type Client struct { + ReadTimeout, WriteTimeout time.Duration + RemoteAddr net.Addr + LocalAddr net.Addr +} + +// NewClient generates a new client to perform a DHCP exchange with, setting the +// read and write timeout fields to defaults. +func NewClient() *Client { + return &Client{ + ReadTimeout: DefaultReadTimeout, + WriteTimeout: DefaultWriteTimeout, + } +} + +// MakeRawUDPPacket converts a payload (a serialized DHCPv4 packet) into a +// raw UDP packet for the specified serverAddr from the specified clientAddr. +func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) { + udp := make([]byte, 8) + binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port)) + binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port)) + binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(payload))) + binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum + + h := ipv4.Header{ + Version: 4, + Len: 20, + TotalLen: 20 + len(udp) + len(payload), + TTL: 64, + Protocol: 17, // UDP + Dst: serverAddr.IP, + Src: clientAddr.IP, + } + ret, err := h.Marshal() + if err != nil { + return nil, err + } + ret = append(ret, udp...) + ret = append(ret, payload...) + return ret, nil +} + +// makeRawSocket creates a socket that can be passed to unix.Sendto. +func makeRawSocket(ifname string) (int, error) { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW) + if err != nil { + return fd, err + } + err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) + if err != nil { + return fd, err + } + err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_HDRINCL, 1) + if err != nil { + return fd, err + } + err = dhcpv4.BindToInterface(fd, ifname) + if err != nil { + return fd, err + } + return fd, nil +} + +// MakeBroadcastSocket creates a socket that can be passed to unix.Sendto +// that will send packets out to the broadcast address. +func MakeBroadcastSocket(ifname string) (int, error) { + fd, err := makeRawSocket(ifname) + if err != nil { + return fd, err + } + err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1) + if err != nil { + return fd, err + } + return fd, nil +} + +// MakeListeningSocket creates a listening socket on 0.0.0.0 for the DHCP client +// port and returns it. +func MakeListeningSocket(ifname string) (int, error) { + return makeListeningSocketWithCustomPort(ifname, ClientPort) +} + +func htons(v uint16) uint16 { + var tmp [2]byte + binary.BigEndian.PutUint16(tmp[:], v) + return binary.LittleEndian.Uint16(tmp[:]) +} + +func makeListeningSocketWithCustomPort(ifname string, port int) (int, error) { + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM, int(htons(unix.ETH_P_IP))) + if err != nil { + return fd, err + } + iface, err := net.InterfaceByName(ifname) + if err != nil { + return fd, err + } + llAddr := unix.SockaddrLinklayer{ + Ifindex: iface.Index, + Protocol: htons(unix.ETH_P_IP), + } + err = unix.Bind(fd, &llAddr) + return fd, err +} + +func toUDPAddr(addr net.Addr, defaultAddr *net.UDPAddr) (*net.UDPAddr, error) { + var uaddr *net.UDPAddr + if addr == nil { + uaddr = defaultAddr + } else { + if addr, ok := addr.(*net.UDPAddr); ok { + uaddr = addr + } else { + return nil, fmt.Errorf("could not convert to net.UDPAddr, got %v instead", reflect.TypeOf(addr)) + } + } + if uaddr.IP.To4() == nil { + return nil, fmt.Errorf("'%s' is not a valid IPv4 address", uaddr.IP) + } + return uaddr, nil +} + +func (c *Client) getLocalUDPAddr() (*net.UDPAddr, error) { + defaultLocalAddr := &net.UDPAddr{IP: net.IPv4zero, Port: ClientPort} + laddr, err := toUDPAddr(c.LocalAddr, defaultLocalAddr) + if err != nil { + return nil, fmt.Errorf("Invalid local address: %s", err) + } + return laddr, nil +} + +func (c *Client) getRemoteUDPAddr() (*net.UDPAddr, error) { + defaultRemoteAddr := &net.UDPAddr{IP: net.IPv4bcast, Port: ServerPort} + raddr, err := toUDPAddr(c.RemoteAddr, defaultRemoteAddr) + if err != nil { + return nil, fmt.Errorf("Invalid remote address: %s", err) + } + return raddr, nil +} + +// Exchange runs a full DORA transaction: Discover, Offer, Request, Acknowledge, +// over UDP. Does not retry in case of failures. Returns a list of DHCPv4 +// structures representing the exchange. It can contain up to four elements, +// ordered as Discovery, Offer, Request and Acknowledge. In case of errors, an +// error is returned, and the list of DHCPv4 objects will be shorted than 4, +// containing all the sent and received DHCPv4 messages. +func (c *Client) Exchange(ifname string, modifiers ...dhcpv4.Modifier) ([]*dhcpv4.DHCPv4, error) { + conversation := make([]*dhcpv4.DHCPv4, 0) + raddr, err := c.getRemoteUDPAddr() + if err != nil { + return nil, err + } + laddr, err := c.getLocalUDPAddr() + if err != nil { + return nil, err + } + // Get our file descriptor for the raw socket we need. + var sfd int + // If the address is not net.IPV4bcast, use a unicast socket. This should + // cover the majority of use cases, but we're essentially ignoring the fact + // that the IP could be the broadcast address of a specific subnet. + if raddr.IP.Equal(net.IPv4bcast) { + sfd, err = MakeBroadcastSocket(ifname) + } else { + sfd, err = makeRawSocket(ifname) + } + if err != nil { + return conversation, err + } + rfd, err := makeListeningSocketWithCustomPort(ifname, laddr.Port) + if err != nil { + return conversation, err + } + + defer func() { + // close the sockets + if err := unix.Close(sfd); err != nil { + log.Printf("unix.Close(sendFd) failed: %v", err) + } + if sfd != rfd { + if err := unix.Close(rfd); err != nil { + log.Printf("unix.Close(recvFd) failed: %v", err) + } + } + }() + + // Discover + discover, err := dhcpv4.NewDiscoveryForInterface(ifname, modifiers...) + if err != nil { + return conversation, err + } + conversation = append(conversation, discover) + + // Offer + offer, err := c.SendReceive(sfd, rfd, discover, dhcpv4.MessageTypeOffer) + if err != nil { + return conversation, err + } + conversation = append(conversation, offer) + + // Request + request, err := dhcpv4.NewRequestFromOffer(offer, modifiers...) + if err != nil { + return conversation, err + } + conversation = append(conversation, request) + + // Ack + ack, err := c.SendReceive(sfd, rfd, request, dhcpv4.MessageTypeAck) + if err != nil { + return conversation, err + } + conversation = append(conversation, ack) + + return conversation, nil +} + +// SendReceive sends a packet (with some write timeout) and waits for a +// response up to some read timeout value. If the message type is not +// MessageTypeNone, it will wait for a specific message type +func (c *Client) SendReceive(sendFd, recvFd int, packet *dhcpv4.DHCPv4, messageType dhcpv4.MessageType) (*dhcpv4.DHCPv4, error) { + raddr, err := c.getRemoteUDPAddr() + if err != nil { + return nil, err + } + laddr, err := c.getLocalUDPAddr() + if err != nil { + return nil, err + } + packetBytes, err := MakeRawUDPPacket(packet.ToBytes(), *raddr, *laddr) + if err != nil { + return nil, err + } + + // Create a goroutine to perform the blocking send, and time it out after + // a certain amount of time. + var ( + destination [net.IPv4len]byte + response *dhcpv4.DHCPv4 + ) + copy(destination[:], raddr.IP.To4()) + remoteAddr := unix.SockaddrInet4{Port: laddr.Port, Addr: destination} + recvErrors := make(chan error, 1) + go func(errs chan<- error) { + // set read timeout + timeout := unix.NsecToTimeval(c.ReadTimeout.Nanoseconds()) + if innerErr := unix.SetsockoptTimeval(recvFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout); innerErr != nil { + errs <- innerErr + return + } + for { + buf := make([]byte, MaxUDPReceivedPacketSize) + n, _, innerErr := unix.Recvfrom(recvFd, buf, 0) + if innerErr != nil { + errs <- innerErr + return + } + + var iph ipv4.Header + if err := iph.Parse(buf[:n]); err != nil { + // skip non-IP data + continue + } + if iph.Protocol != 17 { + // skip non-UDP packets + continue + } + udph := buf[iph.Len:n] + // check source and destination ports + srcPort := int(binary.BigEndian.Uint16(udph[0:2])) + expectedSrcPort := ServerPort + if c.RemoteAddr != nil { + expectedSrcPort = c.RemoteAddr.(*net.UDPAddr).Port + } + if srcPort != expectedSrcPort { + continue + } + dstPort := int(binary.BigEndian.Uint16(udph[2:4])) + expectedDstPort := ClientPort + if c.RemoteAddr != nil { + expectedDstPort = c.LocalAddr.(*net.UDPAddr).Port + } + if dstPort != expectedDstPort { + continue + } + // UDP checksum is not checked + pLen := int(binary.BigEndian.Uint16(udph[4:6])) + payload := buf[iph.Len+8 : iph.Len+8+pLen] + + response, innerErr = dhcpv4.FromBytes(payload) + if innerErr != nil { + errs <- innerErr + return + } + // check that this is a response to our message + if response.TransactionID != packet.TransactionID { + continue + } + // wait for a response message + if response.OpCode != dhcpv4.OpcodeBootReply { + continue + } + // if we are not requested to wait for a specific message type, + // return what we have + if messageType == dhcpv4.MessageTypeNone { + break + } + // break if it's a reply of the desired type, continue otherwise + if response.MessageType() == messageType { + break + } + } + recvErrors <- nil + }(recvErrors) + + // send the request while the goroutine waits for replies + if err = unix.Sendto(sendFd, packetBytes, 0, &remoteAddr); err != nil { + return nil, err + } + + select { + case err = <-recvErrors: + if err == unix.EAGAIN { + return nil, errors.New("timed out while listening for replies") + } + if err != nil { + return nil, err + } + case <-time.After(c.ReadTimeout): + return nil, errors.New("timed out while listening for replies") + } + + return response, nil +} |