summaryrefslogtreecommitdiffhomepage
path: root/dhcpv6
diff options
context:
space:
mode:
authorChristopher Koch <chrisko@google.com>2019-03-04 12:47:50 -0800
committerinsomniac <insomniacslk@users.noreply.github.com>2019-04-04 12:49:08 +0100
commit175868d67987770d2d729186f7676e0b98f925df (patch)
tree2d56b1cb28fddd14ad935fa3c41dcc368b31e30c /dhcpv6
parentb40bd52ae58aee37cec9ef81008e24488350c98f (diff)
client6: new async DHCPv6 client like #250.
- Race-condition-averse. - Supports multiple concurrent requests. - Tested. - Requires a fully compatible net.PacketConn. Signed-off-by: Christopher Koch <chrisko@google.com>
Diffstat (limited to 'dhcpv6')
-rw-r--r--dhcpv6/client6/client.go6
-rw-r--r--dhcpv6/dhcpv6.go9
-rw-r--r--dhcpv6/dhcpv6message.go10
-rw-r--r--dhcpv6/modifiers.go5
-rw-r--r--dhcpv6/nclient6/client.go371
-rw-r--r--dhcpv6/nclient6/client_test.go258
-rw-r--r--dhcpv6/server6/server.go119
-rw-r--r--dhcpv6/server6/server_test.go69
8 files changed, 722 insertions, 125 deletions
diff --git a/dhcpv6/client6/client.go b/dhcpv6/client6/client.go
index b1f8e11..c7bd318 100644
--- a/dhcpv6/client6/client.go
+++ b/dhcpv6/client6/client.go
@@ -199,7 +199,11 @@ func (c *Client) sendReceive(ifname string, packet dhcpv6.DHCPv6, expectedType d
// an error if any. The modifiers will be applied to the Solicit before sending
// it, see modifiers.go
func (c *Client) Solicit(ifname string, modifiers ...dhcpv6.Modifier) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) {
- solicit, err := dhcpv6.NewSolicitForInterface(ifname)
+ iface, err := net.InterfaceByName(ifname)
+ if err != nil {
+ return nil, nil, err
+ }
+ solicit, err := dhcpv6.NewSolicit(iface.HardwareAddr)
if err != nil {
return nil, nil, err
}
diff --git a/dhcpv6/dhcpv6.go b/dhcpv6/dhcpv6.go
index c8c4292..7505dfc 100644
--- a/dhcpv6/dhcpv6.go
+++ b/dhcpv6/dhcpv6.go
@@ -16,6 +16,11 @@ type DHCPv6 interface {
String() string
Summary() string
IsRelay() bool
+
+ // GetInnerMessage returns the innermost encapsulated DHCPv6 message.
+ //
+ // If it is already a message, it will be returned. If it is a relay
+ // message, the encapsulated message will be recursively extracted.
GetInnerMessage() (*Message, error)
GetOption(code OptionCode) []Option
@@ -108,11 +113,11 @@ func DecapsulateRelay(l DHCPv6) (DHCPv6, error) {
}
opt := l.GetOneOption(OptionRelayMsg)
if opt == nil {
- return nil, fmt.Errorf("No OptRelayMsg found")
+ return nil, fmt.Errorf("malformed Relay message: no OptRelayMsg found")
}
relayOpt := opt.(*OptRelayMsg)
if relayOpt.RelayMessage() == nil {
- return nil, fmt.Errorf("Relay message cannot be nil")
+ return nil, fmt.Errorf("malformed Relay message: encapsulated message is empty")
}
return relayOpt.RelayMessage(), nil
}
diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go
index 9e03fa7..9237c1b 100644
--- a/dhcpv6/dhcpv6message.go
+++ b/dhcpv6/dhcpv6message.go
@@ -70,18 +70,14 @@ func NewSolicitWithCID(duid Duid, modifiers ...Modifier) (*Message, error) {
return m, nil
}
-// NewSolicitForInterface creates a new SOLICIT message with DUID-LLT, using the
+// NewSolicit creates a new SOLICIT message with DUID-LLT, using the
// given network interface's hardware address and current time
-func NewSolicitForInterface(ifname string, modifiers ...Modifier) (*Message, error) {
- iface, err := net.InterfaceByName(ifname)
- if err != nil {
- return nil, err
- }
+func NewSolicit(ifaceHWAddr net.HardwareAddr, modifiers ...Modifier) (*Message, error) {
duid := Duid{
Type: DUID_LLT,
HwType: iana.HWTypeEthernet,
Time: GetTime(),
- LinkLayerAddr: iface.HardwareAddr,
+ LinkLayerAddr: ifaceHWAddr,
}
return NewSolicitWithCID(duid, modifiers...)
}
diff --git a/dhcpv6/modifiers.go b/dhcpv6/modifiers.go
index 8c75ea5..eaa370d 100644
--- a/dhcpv6/modifiers.go
+++ b/dhcpv6/modifiers.go
@@ -99,6 +99,11 @@ func WithDomainSearchList(searchlist ...string) Modifier {
}
}
+// WithRapidCommit adds the rapid commit option to a message.
+func WithRapidCommit(d DHCPv6) {
+ d.UpdateOption(&OptionGeneric{OptionCode: OptionRapidCommit})
+}
+
// WithRequestedOptions adds requested options to the packet
func WithRequestedOptions(optionCodes ...OptionCode) Modifier {
return func(d DHCPv6) {
diff --git a/dhcpv6/nclient6/client.go b/dhcpv6/nclient6/client.go
new file mode 100644
index 0000000..dc5dd33
--- /dev/null
+++ b/dhcpv6/nclient6/client.go
@@ -0,0 +1,371 @@
+// Copyright 2018 the u-root Authors and Andrea Barberio. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package nclient6
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log"
+ "net"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/insomniacslk/dhcp/dhcpv6"
+)
+
+// Broadcast destination IP addresses as defined by RFC 3315
+var (
+ AllDHCPRelayAgentsAndServers = &net.UDPAddr{
+ IP: net.ParseIP("ff02::1:2"),
+ Port: dhcpv6.DefaultServerPort,
+ }
+ AllDHCPServers = &net.UDPAddr{
+ IP: net.ParseIP("ff05::1:3"),
+ Port: dhcpv6.DefaultServerPort,
+ }
+)
+
+const (
+ maxUDPReceivedPacketSize = 8192
+ maxMessageSize = 1500
+)
+
+var (
+ // ErrNoResponse is returned when no response packet is received.
+ ErrNoResponse = errors.New("no matching response packet received")
+)
+
+// pendingCh is a channel associated with a pending TransactionID.
+type pendingCh struct {
+ // SendAndRead closes done to indicate that it wishes for no more
+ // messages for this particular XID.
+ done <-chan struct{}
+
+ // ch is used by the receive loop to distribute DHCP messages.
+ ch chan<- *dhcpv6.Message
+}
+
+// Client is a DHCPv6 client.
+type Client struct {
+ ifaceHWAddr net.HardwareAddr
+ conn net.PacketConn
+ timeout time.Duration
+ retry int
+
+ // bufferCap is the channel capacity for each TransactionID.
+ bufferCap int
+
+ // serverAddr is the UDP address to send all packets to.
+ //
+ // This may be an actual broadcast address, or a unicast address.
+ serverAddr *net.UDPAddr
+
+ // closed is an atomic bool set to 1 when done is closed.
+ closed uint32
+
+ // done is closed to unblock the receive loop.
+ done chan struct{}
+
+ // wg protects the receiveLoop.
+ wg sync.WaitGroup
+
+ pendingMu sync.Mutex
+ // pending stores the distribution channels for each pending
+ // TransactionID. receiveLoop uses this map to determine which channel
+ // to send a new DHCP message to.
+ pending map[dhcpv6.TransactionID]*pendingCh
+}
+
+// NewIPv6UDPConn returns a UDP connection bound to both the interface and port
+// given based on a IPv6 DGRAM socket.
+func NewIPv6UDPConn(iface string, port int) (net.PacketConn, error) {
+ return net.ListenUDP("udp6", &net.UDPAddr{
+ Port: port,
+ Zone: iface,
+ })
+}
+
+// New creates a new DHCP client that sends and receives packets on the given
+// interface.
+func New(ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
+ c := &Client{
+ ifaceHWAddr: ifaceHWAddr,
+ timeout: 5 * time.Second,
+ retry: 3,
+ serverAddr: AllDHCPServers,
+ bufferCap: 5,
+
+ done: make(chan struct{}),
+ pending: make(map[dhcpv6.TransactionID]*pendingCh),
+ }
+
+ for _, opt := range opts {
+ opt(c)
+ }
+
+ if c.conn == nil {
+ return nil, fmt.Errorf("require a connection")
+ }
+
+ c.receiveLoop()
+ return c, nil
+}
+
+// Close closes the underlying connection.
+func (c *Client) Close() error {
+ // Make sure not to close done twice.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return nil
+ }
+
+ err := c.conn.Close()
+
+ // Closing c.done sets off a chain reaction:
+ //
+ // Any SendAndRead unblocks trying to receive more messages, which
+ // means rem() gets called.
+ //
+ // rem() should be unblocking receiveLoop if it is blocked.
+ //
+ // receiveLoop should then exit gracefully.
+ close(c.done)
+
+ // Wait for receiveLoop to stop.
+ c.wg.Wait()
+
+ return err
+}
+
+func isErrClosing(err error) bool {
+ // Unfortunately, the epoll-connection-closed error is internal to the
+ // net library.
+ return strings.Contains(err.Error(), "use of closed network connection")
+}
+
+func (c *Client) receiveLoop() {
+ c.wg.Add(1)
+ go func() {
+ defer c.wg.Done()
+ for {
+ // TODO: Clients can send a "max packet size" option in their
+ // packets, IIRC. Choose a reasonable size and set it.
+ b := make([]byte, 1500)
+ n, _, err := c.conn.ReadFrom(b)
+ if err != nil {
+ if !isErrClosing(err) {
+ log.Printf("error reading from UDP connection: %v", err)
+ }
+ return
+ }
+
+ msg, err := dhcpv6.MessageFromBytes(b[:n])
+ if err != nil {
+ // Not a valid DHCP packet; keep listening.
+ continue
+ }
+
+ c.pendingMu.Lock()
+ p, ok := c.pending[msg.TransactionID]
+ if ok {
+ select {
+ case <-p.done:
+ close(p.ch)
+ delete(c.pending, msg.TransactionID)
+
+ // This send may block.
+ case p.ch <- msg:
+ }
+ }
+ c.pendingMu.Unlock()
+ }
+ }()
+}
+
+// ClientOpt is a function that configures the Client.
+type ClientOpt func(*Client)
+
+func withBufferCap(n int) ClientOpt {
+ return func(c *Client) {
+ c.bufferCap = n
+ }
+}
+
+// WithTimeout configures the retransmission timeout.
+//
+// Default is 5 seconds.
+func WithTimeout(d time.Duration) ClientOpt {
+ return func(c *Client) {
+ c.timeout = d
+ }
+}
+
+// WithRetry configures the number of retransmissions to attempt.
+//
+// Default is 3.
+func WithRetry(r int) ClientOpt {
+ return func(c *Client) {
+ c.retry = r
+ }
+}
+
+// WithConn configures the packet connection to use.
+func WithConn(conn net.PacketConn) ClientOpt {
+ return func(c *Client) {
+ c.conn = conn
+ }
+}
+
+// WithBroadcastAddr configures the address to broadcast to.
+func WithBroadcastAddr(n *net.UDPAddr) ClientOpt {
+ return func(c *Client) {
+ c.serverAddr = n
+ }
+}
+
+// Matcher matches DHCP packets.
+type Matcher func(*dhcpv6.Message) bool
+
+// IsMessageType returns a matcher that checks for the message type.
+//
+// If t is MessageTypeNone, all packets are matched.
+func IsMessageType(t dhcpv6.MessageType) Matcher {
+ return func(p *dhcpv6.Message) bool {
+ return p.MessageType == t || t == dhcpv6.MessageTypeNone
+ }
+}
+
+// Solicit sends a solicitation message and returns the first valid
+// advertisement received.
+func (c *Client) Solicit(ctx context.Context, modifiers ...dhcpv6.Modifier) (*dhcpv6.Message, error) {
+ solicit, err := dhcpv6.NewSolicit(c.ifaceHWAddr, modifiers...)
+ if err != nil {
+ return nil, err
+ }
+ msg, err := c.SendAndRead(ctx, c.serverAddr, solicit, IsMessageType(dhcpv6.MessageTypeAdvertise))
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+// Request requests an IP Assignment from peer given an advertise message.
+func (c *Client) Request(ctx context.Context, advertise *dhcpv6.Message, modifiers ...dhcpv6.Modifier) (*dhcpv6.Message, error) {
+ request, err := dhcpv6.NewRequestFromAdvertise(advertise, modifiers...)
+ if err != nil {
+ return nil, err
+ }
+ return c.SendAndRead(ctx, c.serverAddr, request, nil)
+}
+
+// send sends p to destination and returns a response channel.
+//
+// The returned function must be called after all desired responses have been
+// received.
+//
+// Responses will be matched by transaction ID.
+func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Message, func(), error) {
+ c.pendingMu.Lock()
+ if _, ok := c.pending[msg.TransactionID]; ok {
+ c.pendingMu.Unlock()
+ return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID)
+ }
+
+ ch := make(chan *dhcpv6.Message, c.bufferCap)
+ done := make(chan struct{})
+ c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch}
+ c.pendingMu.Unlock()
+
+ cancel := func() {
+ // Why can't we just close ch here?
+ //
+ // Because receiveLoop may potentially be blocked trying to
+ // send on ch. We gotta unblock it first, so it'll unlock the
+ // lock, and then we can take the lock and remove the XID from
+ // the pending transaction map.
+ close(done)
+
+ c.pendingMu.Lock()
+ if p, ok := c.pending[msg.TransactionID]; ok {
+ close(p.ch)
+ delete(c.pending, msg.TransactionID)
+ }
+ c.pendingMu.Unlock()
+ }
+
+ if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil {
+ cancel()
+ return nil, nil, fmt.Errorf("error writing packet to connection: %v", err)
+ }
+ return ch, cancel, nil
+}
+
+// This should never be visible to a user.
+var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded")
+
+// SendAndRead sends a packet p to a destination dest and waits for the first
+// response matching `match` as well as its Transaction ID.
+//
+// If match is nil, the first packet matching the Transaction ID is returned.
+func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6.Message, match Matcher) (*dhcpv6.Message, error) {
+ var response *dhcpv6.Message
+ err := c.retryFn(func(timeout time.Duration) error {
+ ch, rem, err := c.send(dest, msg)
+ if err != nil {
+ return err
+ }
+ defer rem()
+
+ for {
+ select {
+ case <-c.done:
+ return ErrNoResponse
+
+ case <-time.After(timeout):
+ return errDeadlineExceeded
+
+ case <-ctx.Done():
+ return ctx.Err()
+
+ case packet := <-ch:
+ if match == nil || match(packet) {
+ response = packet
+ return nil
+ }
+ }
+ }
+ })
+ if err == errDeadlineExceeded {
+ return nil, ErrNoResponse
+ }
+ if err != nil {
+ return nil, err
+ }
+ return response, nil
+}
+
+func (c *Client) retryFn(fn func(timeout time.Duration) error) error {
+ timeout := c.timeout
+
+ // Each retry takes the amount of timeout at worst.
+ for i := 0; i < c.retry || c.retry < 0; i++ {
+ switch err := fn(timeout); err {
+ case nil:
+ // Got it!
+ return nil
+
+ case errDeadlineExceeded:
+ // Double timeout, then retry.
+ timeout *= 2
+
+ default:
+ return err
+ }
+ }
+
+ return errDeadlineExceeded
+}
diff --git a/dhcpv6/nclient6/client_test.go b/dhcpv6/nclient6/client_test.go
new file mode 100644
index 0000000..cba4ef8
--- /dev/null
+++ b/dhcpv6/nclient6/client_test.go
@@ -0,0 +1,258 @@
+// Copyright 2018 the u-root Authors and Andrea Barberio. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+
+package nclient6
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/hugelgupf/socketpair"
+ "github.com/insomniacslk/dhcp/dhcpv6"
+ "github.com/insomniacslk/dhcp/dhcpv6/server6"
+)
+
+type handler struct {
+ mu sync.Mutex
+ received []*dhcpv6.Message
+
+ // Each received packet can have more than one response (in theory,
+ // from different servers sending different Advertise, for example).
+ responses [][]*dhcpv6.Message
+}
+
+func (h *handler) handle(conn net.PacketConn, peer net.Addr, msg dhcpv6.DHCPv6) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ m := msg.(*dhcpv6.Message)
+
+ h.received = append(h.received, m)
+
+ if len(h.responses) > 0 {
+ resps := h.responses[0]
+ // What should we send in response?
+ for _, resp := range resps {
+ conn.WriteTo(resp.ToBytes(), peer)
+ }
+ h.responses = h.responses[1:]
+ }
+}
+
+func serveAndClient(ctx context.Context, responses [][]*dhcpv6.Message, opt ...ClientOpt) (*Client, net.PacketConn) {
+ // Fake connection between client and server. No raw sockets, no port
+ // weirdness.
+ clientRawConn, serverRawConn, err := socketpair.PacketSocketPair()
+ if err != nil {
+ panic(err)
+ }
+
+ o := []ClientOpt{WithConn(clientRawConn), WithRetry(1), WithTimeout(2 * time.Second)}
+ o = append(o, opt...)
+ mc, err := New(net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...)
+ if err != nil {
+ panic(err)
+ }
+
+ h := &handler{
+ responses: responses,
+ }
+ s, err := server6.NewServer(nil, h.handle, server6.WithConn(serverRawConn))
+ if err != nil {
+ panic(err)
+ }
+ go s.Serve()
+
+ return mc, serverRawConn
+}
+
+func ComparePacket(got *dhcpv6.Message, want *dhcpv6.Message) error {
+ if got == nil && got == want {
+ return nil
+ }
+ if (want == nil || got == nil) && (got != want) {
+ return fmt.Errorf("packet got %v, want %v", got, want)
+ }
+ if bytes.Compare(got.ToBytes(), want.ToBytes()) != 0 {
+ return fmt.Errorf("packet got %v, want %v", got, want)
+ }
+ return nil
+}
+
+func pktsExpected(got []*dhcpv6.Message, want []*dhcpv6.Message) error {
+ if len(got) != len(want) {
+ return fmt.Errorf("got %d packets, want %d packets", len(got), len(want))
+ }
+
+ for i := range got {
+ if err := ComparePacket(got[i], want[i]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func newPacket(xid dhcpv6.TransactionID) *dhcpv6.Message {
+ p, err := dhcpv6.NewMessage()
+ if err != nil {
+ panic(fmt.Sprintf("newpacket: %v", err))
+ }
+ p.TransactionID = xid
+ return p
+}
+
+func TestSendAndReadUntil(t *testing.T) {
+ for _, tt := range []struct {
+ desc string
+ send *dhcpv6.Message
+ server []*dhcpv6.Message
+
+ // If want is nil, we assume server contains what is wanted.
+ want *dhcpv6.Message
+ wantErr error
+ }{
+ {
+ desc: "two response packets",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ server: []*dhcpv6.Message{
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ want: newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "one response packet",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ server: []*dhcpv6.Message{
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ want: newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "one response packet, one invalid XID",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ server: []*dhcpv6.Message{
+ newPacket([3]byte{0x77, 0x33, 0x33}),
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ want: newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "discard wrong XID",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ server: []*dhcpv6.Message{
+ newPacket([3]byte{0, 0, 0}),
+ },
+ want: nil,
+ wantErr: ErrNoResponse,
+ },
+ {
+ desc: "no response, timeout",
+ send: newPacket([3]byte{0x33, 0x33, 0x33}),
+ wantErr: ErrNoResponse,
+ },
+ } {
+ t.Run(tt.desc, func(t *testing.T) {
+ // Both server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, _ := serveAndClient(ctx, [][]*dhcpv6.Message{tt.server},
+ // Use an unbuffered channel to make sure we
+ // have no deadlocks.
+ withBufferCap(0))
+ defer mc.Close()
+
+ rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, tt.send, nil)
+ if err != tt.wantErr {
+ t.Error(err)
+ }
+
+ if err := ComparePacket(rcvd, tt.want); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+ })
+ }
+}
+
+func TestSimpleSendAndReadDiscardGarbage(t *testing.T) {
+ pkt := newPacket([3]byte{0x33, 0x33, 0x33})
+
+ responses := []*dhcpv6.Message{
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ }
+
+ // Both the server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, udpConn := serveAndClient(ctx, [][]*dhcpv6.Message{responses})
+ defer mc.Close()
+
+ // Too short for valid DHCPv4 packet.
+ udpConn.WriteTo([]byte{0x01}, nil)
+
+ rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, pkt, nil)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if err := ComparePacket(rcvd, responses[0]); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+}
+
+func TestMultipleSendAndReadOne(t *testing.T) {
+ for _, tt := range []struct {
+ desc string
+ send []*dhcpv6.Message
+ server [][]*dhcpv6.Message
+ wantErr []error
+ }{
+ {
+ desc: "two requests, two responses",
+ send: []*dhcpv6.Message{
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ newPacket([3]byte{0x44, 0x44, 0x44}),
+ },
+ server: [][]*dhcpv6.Message{
+ []*dhcpv6.Message{ // Response for first packet.
+ newPacket([3]byte{0x33, 0x33, 0x33}),
+ },
+ []*dhcpv6.Message{ // Response for second packet.
+ newPacket([3]byte{0x44, 0x44, 0x44}),
+ },
+ },
+ wantErr: []error{
+ nil,
+ nil,
+ },
+ },
+ } {
+ // Both server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, _ := serveAndClient(ctx, tt.server)
+ defer mc.conn.Close()
+
+ for i, send := range tt.send {
+ rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, send, nil)
+
+ if wantErr := tt.wantErr[i]; err != wantErr {
+ t.Errorf("SendAndReadOne(%v): got %v, want %v", send, err, wantErr)
+ }
+ if err := pktsExpected([]*dhcpv6.Message{rcvd}, tt.server[i]); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+ }
+ }
+}
diff --git a/dhcpv6/server6/server.go b/dhcpv6/server6/server.go
index 6a6d0b7..f6fb826 100644
--- a/dhcpv6/server6/server.go
+++ b/dhcpv6/server6/server.go
@@ -1,11 +1,8 @@
package server6
import (
- "fmt"
"log"
"net"
- "sync"
- "time"
"github.com/insomniacslk/dhcp/dhcpv6"
)
@@ -16,102 +13,66 @@ type Handler func(conn net.PacketConn, peer net.Addr, m dhcpv6.DHCPv6)
// Server represents a DHCPv6 server object
type Server struct {
- conn net.PacketConn
- connMutex sync.Mutex
- shouldStop chan bool
- Handler Handler
- localAddr net.UDPAddr
+ conn net.PacketConn
+ handler Handler
}
-// LocalAddr returns the local address of the listening socket, or nil if not
-// listening
-func (s *Server) LocalAddr() net.Addr {
- s.connMutex.Lock()
- defer s.connMutex.Unlock()
- if s.conn == nil {
- return nil
- }
- return s.conn.LocalAddr()
-}
-
-// ActivateAndServe starts the DHCPv6 server. The listener will run in
-// background, and can be interrupted with `Server.Close`.
-func (s *Server) ActivateAndServe() error {
- s.connMutex.Lock()
- if s.conn != nil {
- // this may panic if s.conn is closed but not reset properly. For that
- // you should use `Server.Close`.
- s.Close()
- }
- conn, err := net.ListenUDP("udp6", &s.localAddr)
- if err != nil {
- s.connMutex.Unlock()
- return err
- }
- s.conn = conn
- s.connMutex.Unlock()
- var (
- pc *net.UDPConn
- ok bool
- )
- if pc, ok = s.conn.(*net.UDPConn); !ok {
- return fmt.Errorf("error: not an UDPConn")
- }
- if pc == nil {
- return fmt.Errorf("ActivateAndServe: invalid nil PacketConn")
- }
- log.Printf("Server listening on %s", pc.LocalAddr())
+// Serve starts the DHCPv6 server. The listener will run in background, and can
+// be interrupted with `Server.Close`.
+func (s *Server) Serve() {
+ log.Printf("Server listening on %s", s.conn.LocalAddr())
log.Print("Ready to handle requests")
+
for {
- select {
- case <-s.shouldStop:
- break
- case <-time.After(time.Millisecond):
- }
- pc.SetReadDeadline(time.Now().Add(time.Second))
rbuf := make([]byte, 4096) // FIXME this is bad
- n, peer, err := pc.ReadFrom(rbuf)
+ n, peer, err := s.conn.ReadFrom(rbuf)
if err != nil {
- switch err.(type) {
- case net.Error:
- if !err.(net.Error).Timeout() {
- return err
- }
- // if timeout, silently skip and continue
- default:
- // complain and continue
- log.Printf("Error reading from packet conn: %v", err)
- }
- continue
+ log.Printf("Error reading from packet conn: %v", err)
+ return
}
log.Printf("Handling request from %v", peer)
- m, err := dhcpv6.FromBytes(rbuf[:n])
+
+ d, err := dhcpv6.FromBytes(rbuf[:n])
if err != nil {
log.Printf("Error parsing DHCPv6 request: %v", err)
continue
}
- go s.Handler(pc, peer, m)
+
+ go s.handler(s.conn, peer, d)
}
}
// Close sends a termination request to the server, and closes the UDP listener
func (s *Server) Close() error {
- s.shouldStop <- true
- s.connMutex.Lock()
- defer s.connMutex.Unlock()
- if s.conn != nil {
- ret := s.conn.Close()
- s.conn = nil
- return ret
+ return s.conn.Close()
+}
+
+// A ServerOpt configures a Server.
+type ServerOpt func(s *Server)
+
+// WithConn configures a server with the given connection.
+func WithConn(conn net.PacketConn) ServerOpt {
+ return func(s *Server) {
+ s.conn = conn
}
- return nil
}
// NewServer initializes and returns a new Server object
-func NewServer(addr net.UDPAddr, handler Handler) *Server {
- return &Server{
- localAddr: addr,
- Handler: handler,
- shouldStop: make(chan bool, 1),
+func NewServer(addr *net.UDPAddr, handler Handler, opt ...ServerOpt) (*Server, error) {
+ s := &Server{
+ handler: handler,
+ }
+
+ for _, o := range opt {
+ o(s)
+ }
+
+ if s.conn == nil {
+ conn, err := net.ListenUDP("udp6", addr)
+ if err != nil {
+ return nil, err
+ }
+ s.conn = conn
}
+ return s, nil
}
diff --git a/dhcpv6/server6/server_test.go b/dhcpv6/server6/server_test.go
index 3d2a365..05d62cb 100644
--- a/dhcpv6/server6/server_test.go
+++ b/dhcpv6/server6/server_test.go
@@ -1,62 +1,58 @@
package server6
import (
+ "context"
"log"
"net"
"testing"
- "time"
"github.com/insomniacslk/dhcp/dhcpv6"
- "github.com/insomniacslk/dhcp/dhcpv6/client6"
+ "github.com/insomniacslk/dhcp/dhcpv6/nclient6"
"github.com/insomniacslk/dhcp/interfaces"
"github.com/stretchr/testify/require"
)
+type fakeUnconnectedConn struct {
+ *net.UDPConn
+}
+
+func (f fakeUnconnectedConn) WriteTo(b []byte, _ net.Addr) (int, error) {
+ return f.UDPConn.Write(b)
+}
+
+func (f fakeUnconnectedConn) ReadFrom(b []byte) (int, net.Addr, error) {
+ n, err := f.Read(b)
+ return n, nil, err
+}
+
// utility function to set up a client and a server instance and run it in
// background. The caller needs to call Server.Close() once finished.
-func setUpClientAndServer(handler Handler) (*client6.Client, *Server) {
- laddr := net.UDPAddr{
+func setUpClientAndServer(handler Handler) (*nclient6.Client, *Server) {
+ laddr := &net.UDPAddr{
IP: net.ParseIP("::1"),
Port: 0,
}
- s := NewServer(laddr, handler)
- go s.ActivateAndServe()
-
- c := client6.NewClient()
- c.LocalAddr = &net.UDPAddr{
- IP: net.ParseIP("::1"),
- }
- for {
- if s.LocalAddr() != nil {
- break
- }
- time.Sleep(10 * time.Millisecond)
- log.Printf("Waiting for server to run...")
- }
- c.RemoteAddr = &net.UDPAddr{
- IP: net.ParseIP("::1"),
- Port: s.LocalAddr().(*net.UDPAddr).Port,
+ s, err := NewServer(laddr, handler)
+ if err != nil {
+ panic(err)
}
+ go s.Serve()
- return c, s
-}
+ clientConn, err := net.DialUDP("udp6", &net.UDPAddr{IP: net.ParseIP("::1")}, s.conn.LocalAddr().(*net.UDPAddr))
+ if err != nil {
+ panic(err)
+ }
-func TestNewServer(t *testing.T) {
- laddr := net.UDPAddr{
- IP: net.ParseIP("::1"),
- Port: 0,
+ c, err := nclient6.New(net.HardwareAddr{1, 2, 3, 4, 5, 6},
+ nclient6.WithConn(fakeUnconnectedConn{clientConn}))
+ if err != nil {
+ panic(err)
}
- handler := func(conn net.PacketConn, peer net.Addr, m dhcpv6.DHCPv6) {}
- s := NewServer(laddr, handler)
- defer s.Close()
- require.NotNil(t, s)
- require.Nil(t, s.conn)
- require.Equal(t, laddr, s.localAddr)
- require.NotNil(t, s.Handler)
+ return c, s
}
-func TestServerActivateAndServe(t *testing.T) {
+func TestServer(t *testing.T) {
handler := func(conn net.PacketConn, peer net.Addr, m dhcpv6.DHCPv6) {
msg := m.(*dhcpv6.Message)
adv, err := dhcpv6.NewAdvertiseFromSolicit(msg)
@@ -68,6 +64,7 @@ func TestServerActivateAndServe(t *testing.T) {
log.Printf("Cannot reply to client: %v", err)
}
}
+
c, s := setUpClientAndServer(handler)
defer s.Close()
@@ -75,6 +72,6 @@ func TestServerActivateAndServe(t *testing.T) {
require.NoError(t, err)
require.NotEqual(t, 0, len(ifaces))
- _, _, err = c.Solicit(ifaces[0].Name)
+ _, err = c.Solicit(context.Background(), dhcpv6.WithRapidCommit)
require.NoError(t, err)
}