diff options
Diffstat (limited to 'dhcpv4')
-rw-r--r-- | dhcpv4/client.go | 185 |
1 files changed, 168 insertions, 17 deletions
diff --git a/dhcpv4/client.go b/dhcpv4/client.go index e753b55..e6866ca 100644 --- a/dhcpv4/client.go +++ b/dhcpv4/client.go @@ -6,6 +6,9 @@ import ( "net" "os" "time" + "fmt" + "reflect" + "log" "golang.org/x/net/ipv4" "golang.org/x/sys/unix" @@ -28,9 +31,12 @@ var ( ) // Client is the object that actually performs the DHCP exchange. It currently -// only has read and write timeout values. +// only has read and write timeout values, plus (optional) local and remote +// addresses. type Client struct { ReadTimeout, WriteTimeout time.Duration + RemoteAddr net.Addr + LocalAddr net.Addr } // NewClient generates a new client to perform a DHCP exchange with, setting the @@ -42,12 +48,21 @@ func NewClient() *Client { } } -// MakeRawBroadcastPacket converts payload (a serialized DHCPv4 packet) into a -// raw packet suitable for UDP broadcast. +// MakeRawBroadcastPacket leverages MakeRawPacket to create a raw packet suitable +// for UDP broadcast. func MakeRawBroadcastPacket(payload []byte) ([]byte, error) { + log.Printf("Warning: dhcpv4.MakeRawBroadcastPacket() is deprecated and will be removed.") + serverAddr := net.UDPAddr{IP: net.IPv4bcast, Port: ServerPort} + clientAddr := net.UDPAddr{IP: net.IPv4zero, Port: ClientPort} + return MakeRawUDPPacket(payload, serverAddr, clientAddr) +} + +// MakeRawUDPPacket converts a payload (a serialized DHCPv4 packet) into a +// raw UDP packet for the specified serverAddr from the specified clientAddr. +func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) { udp := make([]byte, 8) - binary.BigEndian.PutUint16(udp[:2], ClientPort) - binary.BigEndian.PutUint16(udp[2:4], ServerPort) + binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port)) + binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port)) binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(payload))) binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum @@ -57,8 +72,8 @@ func MakeRawBroadcastPacket(payload []byte) ([]byte, error) { TotalLen: 20 + len(udp) + len(payload), TTL: 64, Protocol: 17, // UDP - Dst: net.IPv4bcast, - Src: net.IPv4zero, + Dst: serverAddr.IP, + Src: clientAddr.IP, } ret, err := h.Marshal() if err != nil { @@ -69,9 +84,8 @@ func MakeRawBroadcastPacket(payload []byte) ([]byte, error) { return ret, nil } -// MakeBroadcastSocket creates a socket that can be passed to unix.Sendto -// that will send packets out to the broadcast address. -func MakeBroadcastSocket(ifname string) (int, error) { +// makeRawSocket creates a socket that can be passed to unix.Sendto. +func makeRawSocket() (int, error) { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW) if err != nil { return fd, err @@ -84,11 +98,17 @@ func MakeBroadcastSocket(ifname string) (int, error) { if err != nil { return fd, err } - err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1) + return fd, nil +} + +// MakeBroadcastSocket creates a socket that can be passed to unix.Sendto +// that will send packets out to the broadcast address. +func MakeBroadcastSocket(ifname string) (int, error) { + fd, err := makeRawSocket() if err != nil { return fd, err } - err = BindToInterface(fd, ifname) + err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1) if err != nil { return fd, err } @@ -118,6 +138,41 @@ func MakeListeningSocket(ifname string) (int, error) { return fd, nil } +func toUDPAddr(addr net.Addr, defaultAddr *net.UDPAddr) (*net.UDPAddr, error) { + var uaddr *net.UDPAddr + if addr == nil { + uaddr = defaultAddr + } else { + if addr, ok := addr.(*net.UDPAddr); ok { + uaddr = addr + } else { + return nil, fmt.Errorf("could not convert to net.UDPAddr, got %v instead", reflect.TypeOf(addr)) + } + } + if uaddr.IP.To4() == nil { + return nil, fmt.Errorf("'%s' is not a valid IPv4 address", uaddr.IP) + } + return uaddr, nil +} + +func (c *Client) getLocalUDPAddr() (*net.UDPAddr, error) { + defaultLocalAddr := &net.UDPAddr{IP: net.IPv4zero, Port: ClientPort} + laddr, err := toUDPAddr(c.LocalAddr, defaultLocalAddr) + if err != nil { + return nil, fmt.Errorf("Invalid local address: %s", err) + } + return laddr, nil +} + +func (c *Client) getRemoteUDPAddr() (*net.UDPAddr, error) { + defaultRemoteAddr := &net.UDPAddr{IP: net.IPv4bcast, Port: ServerPort} + raddr, err := toUDPAddr(c.RemoteAddr, defaultRemoteAddr) + if err != nil { + return nil, fmt.Errorf("Invalid remote address: %s", err) + } + return raddr, nil +} + // Exchange runs a full DORA transaction: Discover, Offer, Request, Acknowledge, // over UDP. Does not retry in case of failures. Returns a list of DHCPv4 // structures representing the exchange. It can contain up to four elements, @@ -127,9 +182,20 @@ func MakeListeningSocket(ifname string) (int, error) { func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier) ([]*DHCPv4, error) { conversation := make([]*DHCPv4, 0) var err error - - // Get our file descriptor for the broadcast socket. - sfd, err := MakeBroadcastSocket(ifname) + raddr, err := c.getRemoteUDPAddr() + if err != nil { + return nil, err + } + // Get our file descriptor for the raw socket we need. + var sfd int + // If the address is not net.IPV4bcast, use a unicast socket. This should + // cover the majority of use cases, but we're essentially ignoring the fact + // that the IP could be the broadcast address of a specific subnet. + if raddr.IP.Equal(net.IPv4bcast) { + sfd, err = MakeBroadcastSocket(ifname) + } else { + sfd, err = makeRawSocket() + } if err != nil { return conversation, err } @@ -151,7 +217,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier conversation = append(conversation, discover) // Offer - offer, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout, MessageTypeOffer) + offer, err := c.sendReceive(sfd, rfd, discover, MessageTypeOffer) if err != nil { return conversation, err } @@ -165,7 +231,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier conversation = append(conversation, request) // Ack - ack, err := BroadcastSendReceive(sfd, rfd, request, c.ReadTimeout, c.WriteTimeout, MessageTypeAck) + ack, err := c.sendReceive(sfd, rfd, request, MessageTypeAck) if err != nil { return conversation, err } @@ -173,10 +239,95 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier return conversation, nil } +// sendReceive sends a packet (with some write timeout) and waits for a +// response up to some read timeout value. If the message type is not +// MessageTypeNone, it will wait for a specific message type +func (c *Client) sendReceive(sendFd, recvFd int, packet *DHCPv4, messageType MessageType) (*DHCPv4, error) { + raddr, err := c.getRemoteUDPAddr() + if err != nil { + return nil, err + } + laddr, err := c.getLocalUDPAddr() + if err != nil { + return nil, err + } + packetBytes, err := MakeRawUDPPacket(packet.ToBytes(), *raddr, *laddr) + if err != nil { + return nil, err + } + + // Create a goroutine to perform the blocking send, and time it out after + // a certain amount of time. + var ( + destination [4]byte + response *DHCPv4 + ) + copy(destination[:], raddr.IP.To4()) + remoteAddr := unix.SockaddrInet4{Port: laddr.Port, Addr: destination} + recvErrors := make(chan error, 1) + go func(errs chan<- error) { + conn, innerErr := net.FileConn(os.NewFile(uintptr(recvFd), "")) + if err != nil { + errs <- innerErr + return + } + defer conn.Close() + conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + + for { + buf := make([]byte, MaxUDPReceivedPacketSize) + n, _, _, _, innerErr := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{}) + if innerErr != nil { + errs <- innerErr + return + } + + response, innerErr = FromBytes(buf[:n]) + if err != nil { + errs <- innerErr + return + } + // check that this is a response to our message + if response.TransactionID() != packet.TransactionID() { + continue + } + // wait for a response message + if response.Opcode() != OpcodeBootReply { + continue + } + // if we are not requested to wait for a specific message type, + // return what we have + if messageType == MessageTypeNone { + break + } + // break if it's a reply of the desired type, continue otherwise + if response.MessageType() != nil && *response.MessageType() == messageType { + break + } + } + recvErrors <- nil + }(recvErrors) + if err = unix.Sendto(sendFd, packetBytes, 0, &remoteAddr); err != nil { + return nil, err + } + + select { + case err = <-recvErrors: + if err != nil { + return nil, err + } + case <-time.After(c.ReadTimeout): + return nil, errors.New("timed out while listening for replies") + } + + return response, nil +} + // BroadcastSendReceive broadcasts packet (with some write timeout) and waits for a // response up to some read timeout value. If the message type is not // MessageTypeNone, it will wait for a specific message type func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration, messageType MessageType) (*DHCPv4, error) { + log.Printf("Warning: dhcpv4.BroadcastSendAndReceive() is deprecated and will be removed. You should use dhcpv4.client.Exchange() instead.") packetBytes, err := MakeRawBroadcastPacket(packet.ToBytes()) if err != nil { return nil, err |