summaryrefslogtreecommitdiffhomepage
path: root/dhcpv6/nclient6
diff options
context:
space:
mode:
authorChristopher Koch <chrisko@google.com>2019-03-04 12:47:50 -0800
committerinsomniac <insomniacslk@users.noreply.github.com>2019-04-04 12:49:08 +0100
commit175868d67987770d2d729186f7676e0b98f925df (patch)
tree2d56b1cb28fddd14ad935fa3c41dcc368b31e30c /dhcpv6/nclient6
parentb40bd52ae58aee37cec9ef81008e24488350c98f (diff)
client6: new async DHCPv6 client like #250.
- Race-condition-averse. - Supports multiple concurrent requests. - Tested. - Requires a fully compatible net.PacketConn. Signed-off-by: Christopher Koch <chrisko@google.com>
Diffstat (limited to 'dhcpv6/nclient6')
-rw-r--r--dhcpv6/nclient6/client.go371
-rw-r--r--dhcpv6/nclient6/client_test.go258
2 files changed, 629 insertions, 0 deletions
diff --git a/dhcpv6/nclient6/client.go b/dhcpv6/nclient6/client.go
new file mode 100644
index 0000000..dc5dd33
--- /dev/null
+++ b/dhcpv6/nclient6/client.go
@@ -0,0 +1,371 @@
+// Copyright 2018 the u-root Authors and Andrea Barberio. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package nclient6
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log"
+ "net"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/insomniacslk/dhcp/dhcpv6"
+)
+
+// Broadcast destination IP addresses as defined by RFC 3315
+var (
+ AllDHCPRelayAgentsAndServers = &net.UDPAddr{
+ IP: net.ParseIP("ff02::1:2"),
+ Port: dhcpv6.DefaultServerPort,
+ }
+ AllDHCPServers = &net.UDPAddr{
+ IP: net.ParseIP("ff05::1:3"),
+ Port: dhcpv6.DefaultServerPort,
+ }
+)
+
+const (
+ maxUDPReceivedPacketSize = 8192
+ maxMessageSize = 1500
+)
+
+var (
+ // ErrNoResponse is returned when no response packet is received.
+ ErrNoResponse = errors.New("no matching response packet received")
+)
+
+// pendingCh is a channel associated with a pending TransactionID.
+type pendingCh struct {
+ // SendAndRead closes done to indicate that it wishes for no more
+ // messages for this particular XID.
+ done <-chan struct{}
+
+ // ch is used by the receive loop to distribute DHCP messages.
+ ch chan<- *dhcpv6.Message
+}
+
+// Client is a DHCPv6 client.
+type Client struct {
+ ifaceHWAddr net.HardwareAddr
+ conn net.PacketConn
+ timeout time.Duration
+ retry int
+
+ // bufferCap is the channel capacity for each TransactionID.
+ bufferCap int
+
+ // serverAddr is the UDP address to send all packets to.
+ //
+ // This may be an actual broadcast address, or a unicast address.
+ serverAddr *net.UDPAddr
+
+ // closed is an atomic bool set to 1 when done is closed.
+ closed uint32
+
+ // done is closed to unblock the receive loop.
+ done chan struct{}
+
+ // wg protects the receiveLoop.
+ wg sync.WaitGroup
+
+ pendingMu sync.Mutex
+ // pending stores the distribution channels for each pending
+ // TransactionID. receiveLoop uses this map to determine which channel
+ // to send a new DHCP message to.
+ pending map[dhcpv6.TransactionID]*pendingCh
+}
+
+// NewIPv6UDPConn returns a UDP connection bound to both the interface and port
+// given based on a IPv6 DGRAM socket.
+func NewIPv6UDPConn(iface string, port int) (net.PacketConn, error) {
+ return net.ListenUDP("udp6", &net.UDPAddr{
+ Port: port,
+ Zone: iface,
+ })
+}
+
+// New creates a new DHCP client that sends and receives packets on the given
+// interface.
+func New(ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
+ c := &Client{
+ ifaceHWAddr: ifaceHWAddr,
+ timeout: 5 * time.Second,
+ retry: 3,
+ serverAddr: AllDHCPServers,
+ bufferCap: 5,
+
+ done: make(chan struct{}),
+ pending: make(map[dhcpv6.TransactionID]*pendingCh),
+ }
+
+ for _, opt := range opts {
+ opt(c)
+ }
+
+ if c.conn == nil {
+ return nil, fmt.Errorf("require a connection")
+ }
+
+ c.receiveLoop()
+ return c, nil
+}
+
+// Close closes the underlying connection.
+func (c *Client) Close() error {
+ // Make sure not to close done twice.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return nil
+ }
+
+ err := c.conn.Close()
+
+ // Closing c.done sets off a chain reaction:
+ //
+ // Any SendAndRead unblocks trying to receive more messages, which
+ // means rem() gets called.
+ //
+ // rem() should be unblocking receiveLoop if it is blocked.
+ //
+ // receiveLoop should then exit gracefully.
+ close(c.done)
+
+ // Wait for receiveLoop to stop.
+ c.wg.Wait()
+
+ return err
+}
+
+func isErrClosing(err error) bool {
+ // Unfortunately, the epoll-connection-closed error is internal to the
+ // net library.
+ return strings.Contains(err.Error(), "use of closed network connection")
+}
+
+func (c *Client) receiveLoop() {
+ c.wg.Add(1)
+ go func() {
+ defer c.wg.Done()
+ for {
+ // TODO: Clients can send a "max packet size" option in their
+ // packets, IIRC. Choose a reasonable size and set it.
+ b := make([]byte, 1500)
+ n, _, err := c.conn.ReadFrom(b)
+ if err != nil {
+ if !isErrClosing(err) {
+ log.Printf("error reading from UDP connection: %v", err)
+ }
+ return
+ }
+
+ msg, err := dhcpv6.MessageFromBytes(b[:n])
+ if err != nil {
+ // Not a valid DHCP packet; keep listening.
+ continue
+ }
+
+ c.pendingMu.Lock()
+ p, ok := c.pending[msg.TransactionID]
+ if ok {
+ select {
+ case <-p.done:
+ close(p.ch)
+ delete(c.pending, msg.TransactionID)
+
+ // This send may block.
+ case p.ch <- msg:
+ }
+ }
+ c.pendingMu.Unlock()
+ }
+ }()
+}
+
+// ClientOpt is a function that configures the Client.
+type ClientOpt func(*Client)
+
+func withBufferCap(n int) ClientOpt {
+ return func(c *Client) {
+ c.bufferCap = n
+ }
+}
+
+// WithTimeout configures the retransmission timeout.
+//
+// Default is 5 seconds.
+func WithTimeout(d time.Duration) ClientOpt {
+ return func(c *Client) {
+ c.timeout = d
+ }
+}
+
+// WithRetry configures the number of retransmissions to attempt.
+//
+// Default is 3.
+func WithRetry(r int) ClientOpt {
+ return func(c *Client) {
+ c.retry = r
+ }
+}
+
+// WithConn configures the packet connection to use.
+func WithConn(conn net.PacketConn) ClientOpt {
+ return func(c *Client) {
+ c.conn = conn
+ }
+}
+
+// WithBroadcastAddr configures the address to broadcast to.
+func WithBroadcastAddr(n *net.UDPAddr) ClientOpt {
+ return func(c *Client) {
+ c.serverAddr = n
+ }
+}
+
+// Matcher matches DHCP packets.
+type Matcher func(*dhcpv6.Message) bool
+
+// IsMessageType returns a matcher that checks for the message type.
+//
+// If t is MessageTypeNone, all packets are matched.
+func IsMessageType(t dhcpv6.MessageType) Matcher {
+ return func(p *dhcpv6.Message) bool {
+ return p.MessageType == t || t == dhcpv6.MessageTypeNone
+ }
+}
+
+// Solicit sends a solicitation message and returns the first valid
+// advertisement received.
+func (c *Client) Solicit(ctx context.Context, modifiers ...dhcpv6.Modifier) (*dhcpv6.Message, error) {
+ solicit, err := dhcpv6.NewSolicit(c.ifaceHWAddr, modifiers...)
+ if err != nil {
+ return nil, err
+ }
+ msg, err := c.SendAndRead(ctx, c.serverAddr, solicit, IsMessageType(dhcpv6.MessageTypeAdvertise))
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+// Request requests an IP Assignment from peer given an advertise message.
+func (c *Client) Request(ctx context.Context, advertise *dhcpv6.Message, modifiers ...dhcpv6.Modifier) (*dhcpv6.Message, error) {
+ request, err := dhcpv6.NewRequestFromAdvertise(advertise, modifiers...)
+ if err != nil {
+ return nil, err
+ }
+ return c.SendAndRead(ctx, c.serverAddr, request, nil)
+}
+
+// send sends p to destination and returns a response channel.
+//
+// The returned function must be called after all desired responses have been
+// received.
+//
+// Responses will be matched by transaction ID.
+func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Message, func(), error) {
+ c.pendingMu.Lock()
+ if _, ok := c.pending[msg.TransactionID]; ok {
+ c.pendingMu.Unlock()
+ return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID)
+ }
+
+ ch := make(chan *dhcpv6.Message, c.bufferCap)
+ done := make(chan struct{})
+ c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch}
+ c.pendingMu.Unlock()
+
+ cancel := func() {
+ // Why can't we just close ch here?
+ //
+ // Because receiveLoop may potentially be blocked trying to
+ // send on ch. We gotta unblock it first, so it'll unlock the
+ // lock, and then we can take the lock and remove the XID from
+ // the pending transaction map.
+ close(done)
+
+ c.pendingMu.Lock()
+ if p, ok := c.pending[msg.TransactionID]; ok {
+ close(p.ch)
+ delete(c.pending, msg.TransactionID)
+ }
+ c.pendingMu.Unlock()
+ }
+
+ if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil {
+ cancel()
+ return nil, nil, fmt.Errorf("error writing packet to connection: %v", err)
+ }
+ return ch, cancel, nil
+}
+
+// This should never be visible to a user.
+var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded")
+
+// SendAndRead sends a packet p to a destination dest and waits for the first
+// response matching `match` as well as its Transaction ID.
+//
+// If match is nil, the first packet matching the Transaction ID is returned.
+func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6.Message, match Matcher) (*dhcpv6.Message, error) {
+ var response *dhcpv6.Message
+ err := c.retryFn(func(timeout time.Duration) error {
+ ch, rem, err := c.send(dest, msg)
+ if err != nil {
+ return err
+ }
+ defer rem()
+
+ for {
+ select {
+ case <-c.done:
+ return ErrNoResponse
+
+ case <-time.After(timeout):
+ return errDeadlineExceeded
+
+ case <-ctx.Done():
+ return ctx.Err()
+
+ case packet := <-ch:
+ if match == nil || match(packet) {
+ response = packet
+ return nil
+ }
+ }
+ }
+ })
+ if err == errDeadlineExceeded {
+ return nil, ErrNoResponse
+ }
+ if err != nil {
+ return nil, err
+ }
+ return response, nil
+}
+
+func (c *Client) retryFn(fn func(timeout time.Duration) error) error {
+ timeout := c.timeout
+
+ // Each retry takes the amount of timeout at worst.
+ for i := 0; i < c.retry || c.retry < 0; i++ {
+ switch err := fn(timeout); err {
+ case nil:
+ // Got it!
+ return nil
+
+ case errDeadlineExceeded:
+ // Double timeout, then retry.
+ timeout *= 2
+
+ default:
+ return err
+ }
+ }
+
+ return errDeadlineExceeded
+}
diff --git a/dhcpv6/nclient6/client_test.go b/dhcpv6/nclient6/client_test.go
new file mode 100644
index 0000000..cba4ef8
--- /dev/null
+++ b/dhcpv6/nclient6/client_test.go
@@ -0,0 +1,258 @@
+// Copyright 2018 the u-root Authors and Andrea Barberio. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+
+package nclient6
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/hugelgupf/socketpair"
+ "github.com/insomniacslk/dhcp/dhcpv6"
+ "github.com/insomniacslk/dhcp/dhcpv6/server6"
+)
+
+type handler struct {
+ mu sync.Mutex
+ received []*dhcpv6.Message
+
+ // Each received packet can have more than one response (in theory,
+ // from different servers sending different Advertise, for example).
+ responses [][]*dhcpv6.Message
+}
+
+func (h *handler) handle(conn net.PacketConn, peer net.Addr, msg dhcpv6.DHCPv6) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ m := msg.(*dhcpv6.Message)
+
+ h.received = append(h.received, m)
+
+ if len(h.responses) > 0 {
+ resps := h.responses[0]
+ // What should we send in response?
+ for _, resp := range resps {
+ conn.WriteTo(resp.ToBytes(), peer)
+ }
+ h.responses = h.responses[1:]
+ }
+}
+
+func serveAndClient(ctx context.Context, responses [][]*dhcpv6.Message, opt ...ClientOpt) (*Client, net.PacketConn) {
+ // Fake connection between client and server. No raw sockets, no port
+ // weirdness.
+ clientRawConn, serverRawConn, err := socketpair.PacketSocketPair()
+ if err != nil {
+ panic(err)
+ }
+
+ o := []ClientOpt{WithConn(clientRawConn), WithRetry(1), WithTimeout(2 * time.Second)}
+ o = append(o, opt...)
+ mc, err := New(net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...)
+ if err != nil {
+ panic(err)
+ }
+
+ h := &handler{
+ responses: responses,
+ }
+ s, err := server6.NewServer(nil, h.handle, server6.WithConn(serverRawConn))
+ if err != nil {
+ panic(err)
+ }
+ go s.Serve()
+
+ return mc, serverRawConn
+}
+
+func ComparePacket(got *dhcpv6.Message, want *dhcpv6.Message) error {
+ if got == nil && got == want {
+ return nil
+ }
+ if (want == nil || got == nil) && (got != want) {
+ return fmt.Errorf("packet got %v, want %v", got, want)
+ }
+ if bytes.Compare(got.ToBytes(), want.ToBytes()) != 0 {
+ return fmt.Errorf("packet got %v, want %v", got, want)
+ }
+ return nil
+}
+
+func pktsExpected(got []*dhcpv6.Message, want []*dhcpv6.Message) error {
+ if len(got) != len(want) {
+ return fmt.Errorf("got %d packets, want %d packets", len(got), len(want))
+ }
+
+ for i := range got {
+ if err := ComparePacket(got[i], want[i]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func newPacket(xid dhcpv6.TransactionID) *dhcpv6.Message {
+ p, err := dhcpv6.NewMessage()
+ if err != nil {
+ panic(fmt.Sprintf("newpacket: %v", err))
+ }
+ p.TransactionID = xid
+ return p
+}
+
+func TestSendAndReadUntil(t *testing.T) {
+ for _, tt := range []struct {
+ desc string
+ send *dhcpv6.Message
+ server []*dhcpv6.Message
+
+ // If want is nil, we assume server contains what is wanted.
+ want *dhcpv6.Message
+ wantErr error
+ }{
+ {
+ desc: "two response packets",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ server: []*dhcpv6.Message{
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ want: newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "one response packet",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ server: []*dhcpv6.Message{
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ want: newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "one response packet, one invalid XID",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ server: []*dhcpv6.Message{
+ newPacket([3]byte{0x77, 0x33, 0x33}),
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ want: newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "discard wrong XID",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ server: []*dhcpv6.Message{
+ newPacket([3]byte{0, 0, 0}),
+ },
+ want: nil,
+ wantErr: ErrNoResponse,
+ },
+ {
+ desc: "no response, timeout",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ wantErr: ErrNoResponse,
+ },
+ } {
+ t.Run(tt.desc, func(t *testing.T) {
+ // Both server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, _ := serveAndClient(ctx, [][]*dhcpv6.Message{tt.server},
+ // Use an unbuffered channel to make sure we
+ // have no deadlocks.
+ withBufferCap(0))
+ defer mc.Close()
+
+ rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, tt.send, nil)
+ if err != tt.wantErr {
+ t.Error(err)
+ }
+
+ if err := ComparePacket(rcvd, tt.want); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+ })
+ }
+}
+
+func TestSimpleSendAndReadDiscardGarbage(t *testing.T) {
+ pkt := newPacket([3]byte{0x33, 0x33, 0x33})
+
+ responses := []*dhcpv6.Message{
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ }
+
+ // Both the server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, udpConn := serveAndClient(ctx, [][]*dhcpv6.Message{responses})
+ defer mc.Close()
+
+ // Too short for valid DHCPv4 packet.
+ udpConn.WriteTo([]byte{0x01}, nil)
+
+ rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, pkt, nil)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if err := ComparePacket(rcvd, responses[0]); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+}
+
+func TestMultipleSendAndReadOne(t *testing.T) {
+ for _, tt := range []struct {
+ desc string
+ send []*dhcpv6.Message
+ server [][]*dhcpv6.Message
+ wantErr []error
+ }{
+ {
+ desc: "two requests, two responses",
+ send: []*dhcpv6.Message{
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ newPacket([3]byte{0x44, 0x44, 0x44}),
+ },
+ server: [][]*dhcpv6.Message{
+ []*dhcpv6.Message{ // Response for first packet.
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ []*dhcpv6.Message{ // Response for second packet.
+ newPacket([3]byte{0x44, 0x44, 0x44}),
+ },
+ },
+ wantErr: []error{
+ nil,
+ nil,
+ },
+ },
+ } {
+ // Both server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, _ := serveAndClient(ctx, tt.server)
+ defer mc.conn.Close()
+
+ for i, send := range tt.send {
+ rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, send, nil)
+
+ if wantErr := tt.wantErr[i]; err != wantErr {
+ t.Errorf("SendAndReadOne(%v): got %v, want %v", send, err, wantErr)
+ }
+ if err := pktsExpected([]*dhcpv6.Message{rcvd}, tt.server[i]); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+ }
+ }
+}