summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv4/client.go185
1 files changed, 168 insertions, 17 deletions
diff --git a/dhcpv4/client.go b/dhcpv4/client.go
index e753b55..e6866ca 100644
--- a/dhcpv4/client.go
+++ b/dhcpv4/client.go
@@ -6,6 +6,9 @@ import (
"net"
"os"
"time"
+ "fmt"
+ "reflect"
+ "log"
"golang.org/x/net/ipv4"
"golang.org/x/sys/unix"
@@ -28,9 +31,12 @@ var (
)
// Client is the object that actually performs the DHCP exchange. It currently
-// only has read and write timeout values.
+// only has read and write timeout values, plus (optional) local and remote
+// addresses.
type Client struct {
ReadTimeout, WriteTimeout time.Duration
+ RemoteAddr net.Addr
+ LocalAddr net.Addr
}
// NewClient generates a new client to perform a DHCP exchange with, setting the
@@ -42,12 +48,21 @@ func NewClient() *Client {
}
}
-// MakeRawBroadcastPacket converts payload (a serialized DHCPv4 packet) into a
-// raw packet suitable for UDP broadcast.
+// MakeRawBroadcastPacket leverages MakeRawPacket to create a raw packet suitable
+// for UDP broadcast.
func MakeRawBroadcastPacket(payload []byte) ([]byte, error) {
+ log.Printf("Warning: dhcpv4.MakeRawBroadcastPacket() is deprecated and will be removed.")
+ serverAddr := net.UDPAddr{IP: net.IPv4bcast, Port: ServerPort}
+ clientAddr := net.UDPAddr{IP: net.IPv4zero, Port: ClientPort}
+ return MakeRawUDPPacket(payload, serverAddr, clientAddr)
+}
+
+// MakeRawUDPPacket converts a payload (a serialized DHCPv4 packet) into a
+// raw UDP packet for the specified serverAddr from the specified clientAddr.
+func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) {
udp := make([]byte, 8)
- binary.BigEndian.PutUint16(udp[:2], ClientPort)
- binary.BigEndian.PutUint16(udp[2:4], ServerPort)
+ binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port))
+ binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port))
binary.BigEndian.PutUint16(udp[4:6], uint16(8+len(payload)))
binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum
@@ -57,8 +72,8 @@ func MakeRawBroadcastPacket(payload []byte) ([]byte, error) {
TotalLen: 20 + len(udp) + len(payload),
TTL: 64,
Protocol: 17, // UDP
- Dst: net.IPv4bcast,
- Src: net.IPv4zero,
+ Dst: serverAddr.IP,
+ Src: clientAddr.IP,
}
ret, err := h.Marshal()
if err != nil {
@@ -69,9 +84,8 @@ func MakeRawBroadcastPacket(payload []byte) ([]byte, error) {
return ret, nil
}
-// MakeBroadcastSocket creates a socket that can be passed to unix.Sendto
-// that will send packets out to the broadcast address.
-func MakeBroadcastSocket(ifname string) (int, error) {
+// makeRawSocket creates a socket that can be passed to unix.Sendto.
+func makeRawSocket() (int, error) {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW)
if err != nil {
return fd, err
@@ -84,11 +98,17 @@ func MakeBroadcastSocket(ifname string) (int, error) {
if err != nil {
return fd, err
}
- err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1)
+ return fd, nil
+}
+
+// MakeBroadcastSocket creates a socket that can be passed to unix.Sendto
+// that will send packets out to the broadcast address.
+func MakeBroadcastSocket(ifname string) (int, error) {
+ fd, err := makeRawSocket()
if err != nil {
return fd, err
}
- err = BindToInterface(fd, ifname)
+ err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1)
if err != nil {
return fd, err
}
@@ -118,6 +138,41 @@ func MakeListeningSocket(ifname string) (int, error) {
return fd, nil
}
+func toUDPAddr(addr net.Addr, defaultAddr *net.UDPAddr) (*net.UDPAddr, error) {
+ var uaddr *net.UDPAddr
+ if addr == nil {
+ uaddr = defaultAddr
+ } else {
+ if addr, ok := addr.(*net.UDPAddr); ok {
+ uaddr = addr
+ } else {
+ return nil, fmt.Errorf("could not convert to net.UDPAddr, got %v instead", reflect.TypeOf(addr))
+ }
+ }
+ if uaddr.IP.To4() == nil {
+ return nil, fmt.Errorf("'%s' is not a valid IPv4 address", uaddr.IP)
+ }
+ return uaddr, nil
+}
+
+func (c *Client) getLocalUDPAddr() (*net.UDPAddr, error) {
+ defaultLocalAddr := &net.UDPAddr{IP: net.IPv4zero, Port: ClientPort}
+ laddr, err := toUDPAddr(c.LocalAddr, defaultLocalAddr)
+ if err != nil {
+ return nil, fmt.Errorf("Invalid local address: %s", err)
+ }
+ return laddr, nil
+}
+
+func (c *Client) getRemoteUDPAddr() (*net.UDPAddr, error) {
+ defaultRemoteAddr := &net.UDPAddr{IP: net.IPv4bcast, Port: ServerPort}
+ raddr, err := toUDPAddr(c.RemoteAddr, defaultRemoteAddr)
+ if err != nil {
+ return nil, fmt.Errorf("Invalid remote address: %s", err)
+ }
+ return raddr, nil
+}
+
// Exchange runs a full DORA transaction: Discover, Offer, Request, Acknowledge,
// over UDP. Does not retry in case of failures. Returns a list of DHCPv4
// structures representing the exchange. It can contain up to four elements,
@@ -127,9 +182,20 @@ func MakeListeningSocket(ifname string) (int, error) {
func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier) ([]*DHCPv4, error) {
conversation := make([]*DHCPv4, 0)
var err error
-
- // Get our file descriptor for the broadcast socket.
- sfd, err := MakeBroadcastSocket(ifname)
+ raddr, err := c.getRemoteUDPAddr()
+ if err != nil {
+ return nil, err
+ }
+ // Get our file descriptor for the raw socket we need.
+ var sfd int
+ // If the address is not net.IPV4bcast, use a unicast socket. This should
+ // cover the majority of use cases, but we're essentially ignoring the fact
+ // that the IP could be the broadcast address of a specific subnet.
+ if raddr.IP.Equal(net.IPv4bcast) {
+ sfd, err = MakeBroadcastSocket(ifname)
+ } else {
+ sfd, err = makeRawSocket()
+ }
if err != nil {
return conversation, err
}
@@ -151,7 +217,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier
conversation = append(conversation, discover)
// Offer
- offer, err := BroadcastSendReceive(sfd, rfd, discover, c.ReadTimeout, c.WriteTimeout, MessageTypeOffer)
+ offer, err := c.sendReceive(sfd, rfd, discover, MessageTypeOffer)
if err != nil {
return conversation, err
}
@@ -165,7 +231,7 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier
conversation = append(conversation, request)
// Ack
- ack, err := BroadcastSendReceive(sfd, rfd, request, c.ReadTimeout, c.WriteTimeout, MessageTypeAck)
+ ack, err := c.sendReceive(sfd, rfd, request, MessageTypeAck)
if err != nil {
return conversation, err
}
@@ -173,10 +239,95 @@ func (c *Client) Exchange(ifname string, discover *DHCPv4, modifiers ...Modifier
return conversation, nil
}
+// sendReceive sends a packet (with some write timeout) and waits for a
+// response up to some read timeout value. If the message type is not
+// MessageTypeNone, it will wait for a specific message type
+func (c *Client) sendReceive(sendFd, recvFd int, packet *DHCPv4, messageType MessageType) (*DHCPv4, error) {
+ raddr, err := c.getRemoteUDPAddr()
+ if err != nil {
+ return nil, err
+ }
+ laddr, err := c.getLocalUDPAddr()
+ if err != nil {
+ return nil, err
+ }
+ packetBytes, err := MakeRawUDPPacket(packet.ToBytes(), *raddr, *laddr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create a goroutine to perform the blocking send, and time it out after
+ // a certain amount of time.
+ var (
+ destination [4]byte
+ response *DHCPv4
+ )
+ copy(destination[:], raddr.IP.To4())
+ remoteAddr := unix.SockaddrInet4{Port: laddr.Port, Addr: destination}
+ recvErrors := make(chan error, 1)
+ go func(errs chan<- error) {
+ conn, innerErr := net.FileConn(os.NewFile(uintptr(recvFd), ""))
+ if err != nil {
+ errs <- innerErr
+ return
+ }
+ defer conn.Close()
+ conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
+
+ for {
+ buf := make([]byte, MaxUDPReceivedPacketSize)
+ n, _, _, _, innerErr := conn.(*net.UDPConn).ReadMsgUDP(buf, []byte{})
+ if innerErr != nil {
+ errs <- innerErr
+ return
+ }
+
+ response, innerErr = FromBytes(buf[:n])
+ if err != nil {
+ errs <- innerErr
+ return
+ }
+ // check that this is a response to our message
+ if response.TransactionID() != packet.TransactionID() {
+ continue
+ }
+ // wait for a response message
+ if response.Opcode() != OpcodeBootReply {
+ continue
+ }
+ // if we are not requested to wait for a specific message type,
+ // return what we have
+ if messageType == MessageTypeNone {
+ break
+ }
+ // break if it's a reply of the desired type, continue otherwise
+ if response.MessageType() != nil && *response.MessageType() == messageType {
+ break
+ }
+ }
+ recvErrors <- nil
+ }(recvErrors)
+ if err = unix.Sendto(sendFd, packetBytes, 0, &remoteAddr); err != nil {
+ return nil, err
+ }
+
+ select {
+ case err = <-recvErrors:
+ if err != nil {
+ return nil, err
+ }
+ case <-time.After(c.ReadTimeout):
+ return nil, errors.New("timed out while listening for replies")
+ }
+
+ return response, nil
+}
+
// BroadcastSendReceive broadcasts packet (with some write timeout) and waits for a
// response up to some read timeout value. If the message type is not
// MessageTypeNone, it will wait for a specific message type
func BroadcastSendReceive(sendFd, recvFd int, packet *DHCPv4, readTimeout, writeTimeout time.Duration, messageType MessageType) (*DHCPv4, error) {
+ log.Printf("Warning: dhcpv4.BroadcastSendAndReceive() is deprecated and will be removed. You should use dhcpv4.client.Exchange() instead.")
packetBytes, err := MakeRawBroadcastPacket(packet.ToBytes())
if err != nil {
return nil, err