summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv4/client.go')
-rw-r--r--dhcpv4/client.go85
1 files changed, 61 insertions, 24 deletions
diff --git a/dhcpv4/client.go b/dhcpv4/client.go
index 0e379ba..fbdc280 100644
--- a/dhcpv4/client.go
+++ b/dhcpv4/client.go
@@ -4,6 +4,7 @@ import (
"encoding/binary"
"errors"
"net"
+ "os"
"syscall"
"time"
@@ -94,6 +95,29 @@ func MakeBroadcastSocket(ifname string) (int, error) {
return fd, nil
}
+// MakeListeningSocket creates a listening socket on 0.0.0.0 for the DHCP client
+// port and returns it.
+func MakeListeningSocket(ifname string) (int, error) {
+ fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
+ if err != nil {
+ return fd, err
+ }
+ err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+ if err != nil {
+ return fd, err
+ }
+ var addr [4]byte
+ copy(addr[:], net.IPv4zero.To4())
+ if err = syscall.Bind(fd, &syscall.SockaddrInet4{Port: ClientPort, Addr: addr}); err != nil {
+ return fd, err
+ }
+ err = BindToInterface(fd, ifname)
+ if err != nil {
+ return fd, err
+ }
+ return fd, nil
+}
+
// Exchange runs a full DORA transaction: Discover, Offer, Request, Acknowledge,
// over UDP. Does not retry in case of failures. Returns a list of DHCPv4
// structures representing the exchange. It can contain up to four elements,
@@ -105,7 +129,11 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) {
var err error
// Get our file descriptor for the broadcast socket.
- fd, err := MakeBroadcastSocket(ifname)
+ sfd, err := MakeBroadcastSocket(ifname)
+ if err != nil {
+ return conversation, err
+ }
+ rfd, err := MakeListeningSocket(ifname)
if err != nil {
return conversation, err
}
@@ -120,7 +148,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) {
conversation[0] = *discover
// Offer
- offer, err := BroadcastSendReceive(fd, discover, c.ReadTimeout, c.WriteTimeout)
+ offer, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout)
if err != nil {
return conversation, err
}
@@ -134,7 +162,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) {
conversation = append(conversation, *request)
// Ack
- ack, err := BroadcastSendReceive(fd, discover, c.ReadTimeout, c.WriteTimeout)
+ ack, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout)
if err != nil {
return conversation, err
}
@@ -144,7 +172,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4) ([]DHCPv4, error) {
// BroadcastSendReceive broadcasts packet (with some write timeout) and waits for a
// response up to some read timeout value.
-func BroadcastSendReceive(fd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration) (*DHCPv4, error) {
+func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration) (*DHCPv4, error) {
packetBytes, err := MakeRawBroadcastPacket(packet.ToBytes())
if err != nil {
return nil, err
@@ -155,34 +183,43 @@ func BroadcastSendReceive(fd int, packet *DHCPv4, readTimeout, writeTimeout time
var destination [4]byte
copy(destination[:], net.IPv4bcast.To4())
remoteAddr := syscall.SockaddrInet4{Port: ClientPort, Addr: destination}
- sendErrChan := make(chan error, 1)
- go func() { sendErrChan <- syscall.Sendto(fd, packetBytes, 0, &remoteAddr) }()
+ recvErrors := make(chan error, 1)
+ var response *DHCPv4
+ go func(errs chan<- error) {
+ conn, err := net.FileConn(os.NewFile(uintptr(recvFd), ""))
+ if err != nil {
+ errs <- err
+ return
+ }
+ defer conn.Close()
+ conn.SetReadDeadline(time.Now().Add(readTimeout))
- select {
- case err = <-sendErrChan:
+ buf := make([]byte, MaxUDPReceivedPacketSize)
+ n, _, _, _, err := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{})
if err != nil {
- return nil, err
+ errs <- err
+ return
}
- case <-time.After(writeTimeout):
- return nil, errors.New("timed out while sending broadcast")
- }
- conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: ClientPort})
- if err != nil {
+ response, err = FromBytes(buf[:n])
+ if err != nil {
+ errs <- err
+ return
+ }
+ recvErrors <- nil
+ }(recvErrors)
+ if err = syscall.Sendto(sendFd, packetBytes, 0, &remoteAddr); err != nil {
return nil, err
}
- defer conn.Close()
- conn.SetReadDeadline(time.Now().Add(readTimeout))
- buf := make([]byte, MaxUDPReceivedPacketSize)
- n, _, _, _, err := conn.ReadMsgUDP(buf, []byte{})
- if err != nil {
- return nil, err
+ select {
+ case err = <-recvErrors:
+ if err != nil {
+ return nil, err
+ }
+ case <-time.After(readTimeout):
+ return nil, errors.New("timed out while listening for replies")
}
- response, err := FromBytes(buf[:n])
- if err != nil {
- return nil, err
- }
return response, nil
}