summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4/nclient4/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv4/nclient4/client.go')
-rw-r--r--dhcpv4/nclient4/client.go412
1 files changed, 412 insertions, 0 deletions
diff --git a/dhcpv4/nclient4/client.go b/dhcpv4/nclient4/client.go
new file mode 100644
index 0000000..3c97a60
--- /dev/null
+++ b/dhcpv4/nclient4/client.go
@@ -0,0 +1,412 @@
+// Copyright 2018 the u-root Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+
+// Package nclient4 is a small, minimum-functionality client for DHCPv4.
+//
+// It only supports the 4-way DHCPv4 Discover-Offer-Request-Ack handshake as
+// well as the Request-Ack renewal process.
+package nclient4
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "log"
+ "net"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+)
+
+const (
+ defaultTimeout = 5 * time.Second
+ defaultRetries = 3
+ defaultBufferCap = 5
+ maxMessageSize = 1500
+
+ // ClientPort is the port that DHCP clients listen on.
+ ClientPort = 68
+
+ // ServerPort is the port that DHCP servers and relay agents listen on.
+ ServerPort = 67
+)
+
+var (
+ // DefaultServers is the address of all link-local DHCP servers and
+ // relay agents.
+ DefaultServers = &net.UDPAddr{
+ IP: net.IPv4bcast,
+ Port: ServerPort,
+ }
+)
+
+var (
+ // ErrNoResponse is returned when no response packet is received.
+ ErrNoResponse = errors.New("no matching response packet received")
+)
+
+// pendingCh is a channel associated with a pending TransactionID.
+type pendingCh struct {
+ // SendAndRead closes done to indicate that it wishes for no more
+ // messages for this particular XID.
+ done <-chan struct{}
+
+ // ch is used by the receive loop to distribute DHCP messages.
+ ch chan<- *dhcpv4.DHCPv4
+}
+
+// Client is an IPv4 DHCP client.
+type Client struct {
+ ifaceHWAddr net.HardwareAddr
+ conn net.PacketConn
+ timeout time.Duration
+ retry int
+
+ // bufferCap is the channel capacity for each TransactionID.
+ bufferCap int
+
+ // serverAddr is the UDP address to send all packets to.
+ //
+ // This may be an actual broadcast address, or a unicast address.
+ serverAddr *net.UDPAddr
+
+ // closed is an atomic bool set to 1 when done is closed.
+ closed uint32
+
+ // done is closed to unblock the receive loop.
+ done chan struct{}
+
+ // wg protects any spawned goroutines, namely the receiveLoop.
+ wg sync.WaitGroup
+
+ pendingMu sync.Mutex
+ // pending stores the distribution channels for each pending
+ // TransactionID. receiveLoop uses this map to determine which channel
+ // to send a new DHCP message to.
+ pending map[dhcpv4.TransactionID]*pendingCh
+}
+
+// New returns a client usable with an unconfigured interface.
+func New(ifaceName string, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
+ c := NewWithConn(nil, ifaceHWAddr, opts...)
+
+ // Do this after so that a caller can still use a WithConn to override
+ // the connection.
+ if c.conn == nil {
+ pc, err := NewRawUDPConn(ifaceName, ClientPort)
+ if err != nil {
+ return nil, err
+ }
+ c.conn = pc
+ }
+ return c, nil
+}
+
+// NewWithConn creates a new DHCP client that sends and receives packets on the
+// given interface.
+func NewWithConn(conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) *Client {
+ c := &Client{
+ ifaceHWAddr: ifaceHWAddr,
+ timeout: defaultTimeout,
+ retry: defaultRetries,
+ serverAddr: DefaultServers,
+ bufferCap: defaultBufferCap,
+ conn: conn,
+
+ done: make(chan struct{}),
+ pending: make(map[dhcpv4.TransactionID]*pendingCh),
+ }
+
+ for _, opt := range opts {
+ opt(c)
+ }
+
+ c.wg.Add(1)
+ go c.receiveLoop()
+ return c
+}
+
+// Close closes the underlying connection.
+func (c *Client) Close() error {
+ // Make sure not to close done twice.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return nil
+ }
+
+ err := c.conn.Close()
+
+ // Closing c.done sets off a chain reaction:
+ //
+ // Any SendAndRead unblocks trying to receive more messages, which
+ // means rem() gets called.
+ //
+ // rem() should be unblocking receiveLoop if it is blocked.
+ //
+ // receiveLoop should then exit gracefully.
+ close(c.done)
+
+ // Wait for receiveLoop to stop.
+ c.wg.Wait()
+
+ return err
+}
+
+func isErrClosing(err error) bool {
+ // Unfortunately, the epoll-connection-closed error is internal to the
+ // net library.
+ return strings.Contains(err.Error(), "use of closed network connection")
+}
+
+func (c *Client) receiveLoop() {
+ defer c.wg.Done()
+ for {
+ // TODO: Clients can send a "max packet size" option in their
+ // packets, IIRC. Choose a reasonable size and set it.
+ b := make([]byte, maxMessageSize)
+ n, _, err := c.conn.ReadFrom(b)
+ if err != nil {
+ if !isErrClosing(err) {
+ log.Printf("error reading from UDP connection: %v", err)
+ }
+ return
+ }
+
+ msg, err := dhcpv4.FromBytes(b[:n])
+ if err != nil {
+ // Not a valid DHCP packet; keep listening.
+ continue
+ }
+
+ if msg.OpCode != dhcpv4.OpcodeBootReply {
+ // Not a response message.
+ continue
+ }
+
+ // This is a somewhat non-standard check, by the looks
+ // of RFC 2131. It should work as long as the DHCP
+ // server is spec-compliant for the HWAddr field.
+ if c.ifaceHWAddr != nil && !bytes.Equal(c.ifaceHWAddr, msg.ClientHWAddr) {
+ // Not for us.
+ continue
+ }
+
+ c.pendingMu.Lock()
+ p, ok := c.pending[msg.TransactionID]
+ if ok {
+ select {
+ case <-p.done:
+ close(p.ch)
+ delete(c.pending, msg.TransactionID)
+
+ // This send may block.
+ case p.ch <- msg:
+ }
+ }
+ c.pendingMu.Unlock()
+ }
+}
+
+// ClientOpt is a function that configures the Client.
+type ClientOpt func(*Client)
+
+// WithTimeout configures the retransmission timeout.
+//
+// Default is 5 seconds.
+func WithTimeout(d time.Duration) ClientOpt {
+ return func(c *Client) {
+ c.timeout = d
+ }
+}
+
+func withBufferCap(n int) ClientOpt {
+ return func(c *Client) {
+ c.bufferCap = n
+ }
+}
+
+// WithRetry configures the number of retransmissions to attempt.
+//
+// Default is 3.
+func WithRetry(r int) ClientOpt {
+ return func(c *Client) {
+ c.retry = r
+ }
+}
+
+// WithConn configures the packet connection to use.
+func WithConn(conn net.PacketConn) ClientOpt {
+ return func(c *Client) {
+ c.conn = conn
+ }
+}
+
+// WithServerAddr configures the address to send messages to.
+func WithServerAddr(n *net.UDPAddr) ClientOpt {
+ return func(c *Client) {
+ c.serverAddr = n
+ }
+}
+
+// Matcher matches DHCP packets.
+type Matcher func(*dhcpv4.DHCPv4) bool
+
+// IsMessageType returns a matcher that checks for the message type.
+//
+// If t is MessageTypeNone, all packets are matched.
+func IsMessageType(t dhcpv4.MessageType) Matcher {
+ return func(p *dhcpv4.DHCPv4) bool {
+ return p.MessageType() == t || t == dhcpv4.MessageTypeNone
+ }
+}
+
+// DiscoverOffer sends a DHCPDiscover message and returns the first valid offer
+// received.
+func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (*dhcpv4.DHCPv4, error) {
+ // RFC 2131, Section 4.4.1, Table 5 details what a DISCOVER packet should
+ // contain.
+ discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers,
+ dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(maxMessageSize)))...)
+ if err != nil {
+ return nil, err
+ }
+ return c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer))
+}
+
+// Request completes the 4-way Discover-Offer-Request-Ack handshake.
+//
+// Note that modifiers will be applied *both* to Discover and Request packets.
+func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer, ack *dhcpv4.DHCPv4, err error) {
+ offer, err = c.DiscoverOffer(ctx, modifiers...)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // TODO(chrisko): should this be unicast to the server?
+ req, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers,
+ dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(maxMessageSize)))...)
+ if err != nil {
+ return nil, nil, err
+ }
+ ack, err = c.SendAndRead(ctx, c.serverAddr, req, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+ return offer, ack, nil
+}
+
+// send sends p to destination and returns a response channel.
+//
+// Responses will be matched by transaction ID and ClientHWAddr.
+//
+// The returned lambda function must be called after all desired responses have
+// been received in order to return the Transaction ID to the usable pool.
+func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv4.DHCPv4, cancel func(), err error) {
+ c.pendingMu.Lock()
+ if _, ok := c.pending[msg.TransactionID]; ok {
+ c.pendingMu.Unlock()
+ return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID)
+ }
+
+ ch := make(chan *dhcpv4.DHCPv4, c.bufferCap)
+ done := make(chan struct{})
+ c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch}
+ c.pendingMu.Unlock()
+
+ cancel = func() {
+ // Why can't we just close ch here?
+ //
+ // Because receiveLoop may potentially be blocked trying to
+ // send on ch. We gotta unblock it first, and then we can take
+ // the lock and remove the XID from the pending transaction
+ // map.
+ close(done)
+
+ c.pendingMu.Lock()
+ if p, ok := c.pending[msg.TransactionID]; ok {
+ close(p.ch)
+ delete(c.pending, msg.TransactionID)
+ }
+ c.pendingMu.Unlock()
+ }
+
+ if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil {
+ cancel()
+ return nil, nil, fmt.Errorf("error writing packet to connection: %v", err)
+ }
+ return ch, cancel, nil
+}
+
+// This error should never be visible to users.
+// It is used only to increase the timeout in retryFn.
+var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded")
+
+// SendAndRead sends a packet p to a destination dest and waits for the first
+// response matching `match` as well as its Transaction ID and ClientHWAddr.
+//
+// If match is nil, the first packet matching the Transaction ID and
+// ClientHWAddr is returned.
+func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, p *dhcpv4.DHCPv4, match Matcher) (*dhcpv4.DHCPv4, error) {
+ var response *dhcpv4.DHCPv4
+ err := c.retryFn(func(timeout time.Duration) error {
+ ch, rem, err := c.send(dest, p)
+ if err != nil {
+ return err
+ }
+ defer rem()
+
+ for {
+ select {
+ case <-c.done:
+ return ErrNoResponse
+
+ case <-time.After(timeout):
+ return errDeadlineExceeded
+
+ case <-ctx.Done():
+ return ctx.Err()
+
+ case packet := <-ch:
+ if match == nil || match(packet) {
+ response = packet
+ return nil
+ }
+ }
+ }
+ })
+ if err == errDeadlineExceeded {
+ return nil, ErrNoResponse
+ }
+ if err != nil {
+ return nil, err
+ }
+ return response, nil
+}
+
+func (c *Client) retryFn(fn func(timeout time.Duration) error) error {
+ timeout := c.timeout
+
+ // Each retry takes the amount of timeout at worst.
+ for i := 0; i < c.retry || c.retry < 0; i++ {
+ switch err := fn(timeout); err {
+ case nil:
+ // Got it!
+ return nil
+
+ case errDeadlineExceeded:
+ // Double timeout, then retry.
+ timeout *= 2
+
+ default:
+ return err
+ }
+ }
+
+ return errDeadlineExceeded
+}