diff options
Diffstat (limited to 'dhcpv4/client.go')
-rw-r--r-- | dhcpv4/client.go | 85 |
1 files changed, 61 insertions, 24 deletions
diff --git a/dhcpv4/client.go b/dhcpv4/client.go index 0e379ba..fbdc280 100644 --- a/dhcpv4/client.go +++ b/dhcpv4/client.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "net" + "os" "syscall" "time" @@ -94,6 +95,29 @@ func MakeBroadcastSocket(ifname string) (int, error) { return fd, nil } +// MakeListeningSocket creates a listening socket on 0.0.0.0 for the DHCP client +// port and returns it. +func MakeListeningSocket(ifname string) (int, error) { + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if err != nil { + return fd, err + } + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + if err != nil { + return fd, err + } + var addr [4]byte + copy(addr[:], net.IPv4zero.To4()) + if err = syscall.Bind(fd, &syscall.SockaddrInet4{Port: ClientPort, Addr: addr}); err != nil { + return fd, err + } + err = BindToInterface(fd, ifname) + if err != nil { + return fd, err + } + return fd, nil +} + // Exchange runs a full DORA transaction: Discover, Offer, Request, Acknowledge, // over UDP. Does not retry in case of failures. Returns a list of DHCPv4 // structures representing the exchange. It can contain up to four elements, @@ -105,7 +129,11 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) { var err error // Get our file descriptor for the broadcast socket. - fd, err := MakeBroadcastSocket(ifname) + sfd, err := MakeBroadcastSocket(ifname) + if err != nil { + return conversation, err + } + rfd, err := MakeListeningSocket(ifname) if err != nil { return conversation, err } @@ -120,7 +148,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) { conversation[0] = *discover // Offer - offer, err := BroadcastSendReceive(fd, discover, c.ReadTimeout, c.WriteTimeout) + offer, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout) if err != nil { return conversation, err } @@ -134,7 +162,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) { conversation = append(conversation, *request) // Ack - ack, err := BroadcastSendReceive(fd, discover, c.ReadTimeout, c.WriteTimeout) + ack, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout) if err != nil { return conversation, err } @@ -144,7 +172,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) { // BroadcastSendReceive broadcasts packet (with some write timeout) and waits for a // response up to some read timeout value. -func BroadcastSendReceive(fd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration) (*DHCPv4, error) { +func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration) (*DHCPv4, error) { packetBytes, err := MakeRawBroadcastPacket(packet.ToBytes()) if err != nil { return nil, err @@ -155,34 +183,43 @@ func BroadcastSendReceive(fd int, packet *DHCPv4, readTimeout, writeTimeout time var destination [4]byte copy(destination[:], net.IPv4bcast.To4()) remoteAddr := syscall.SockaddrInet4{Port: ClientPort, Addr: destination} - sendErrChan := make(chan error, 1) - go func() { sendErrChan <- syscall.Sendto(fd, packetBytes, 0, &remoteAddr) }() + recvErrors := make(chan error, 1) + var response *DHCPv4 + go func(errs chan<- error) { + conn, err := net.FileConn(os.NewFile(uintptr(recvFd), "")) + if err != nil { + errs <- err + return + } + defer conn.Close() + conn.SetReadDeadline(time.Now().Add(readTimeout)) - select { - case err = <-sendErrChan: + buf := make([]byte, MaxUDPReceivedPacketSize) + n, _, _, _, err := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{}) if err != nil { - return nil, err + errs <- err + return } - case <-time.After(writeTimeout): - return nil, errors.New("timed out while sending broadcast") - } - conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: ClientPort}) - if err != nil { + response, err = FromBytes(buf[:n]) + if err != nil { + errs <- err + return + } + recvErrors <- nil + }(recvErrors) + if err = syscall.Sendto(sendFd, packetBytes, 0, &remoteAddr); err != nil { return nil, err } - defer conn.Close() - conn.SetReadDeadline(time.Now().Add(readTimeout)) - buf := make([]byte, MaxUDPReceivedPacketSize) - n, _, _, _, err := conn.ReadMsgUDP(buf, []byte{}) - if err != nil { - return nil, err + select { + case err = <-recvErrors: + if err != nil { + return nil, err + } + case <-time.After(readTimeout): + return nil, errors.New("timed out while listening for replies") } - response, err := FromBytes(buf[:n]) - if err != nil { - return nil, err - } return response, nil } |