summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv4/dhcpv4.go6
-rw-r--r--dhcpv4/nclient4/client.go412
-rw-r--r--dhcpv4/nclient4/client_test.go326
-rw-r--r--dhcpv4/nclient4/conn_linux.go173
-rw-r--r--dhcpv4/nclient4/ipv4.go376
-rw-r--r--dhcpv4/server4/server.go117
-rw-r--r--dhcpv4/server4/server_test.go77
7 files changed, 1366 insertions, 121 deletions
diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go
index 0373764..f74ed40 100644
--- a/dhcpv4/dhcpv4.go
+++ b/dhcpv4/dhcpv4.go
@@ -225,13 +225,17 @@ func NewRequestFromOffer(offer *DHCPv4, modifiers ...Modifier) (*DHCPv4, error)
// find server IP address
serverIP := offer.ServerIdentifier()
if serverIP == nil {
- return nil, errors.New("Missing Server IP Address in DHCP Offer")
+ if offer.ServerIPAddr == nil || offer.ServerIPAddr.IsUnspecified() {
+ return nil, fmt.Errorf("missing Server IP Address in DHCP Offer")
+ }
+ serverIP = offer.ServerIPAddr
}
return New(PrependModifiers(modifiers,
WithReply(offer),
WithMessageType(MessageTypeRequest),
WithServerIP(serverIP),
+ WithClientIP(offer.ClientIPAddr),
WithOption(OptRequestedIPAddress(offer.YourIPAddr)),
WithOption(OptServerIdentifier(serverIP)),
)...)
diff --git a/dhcpv4/nclient4/client.go b/dhcpv4/nclient4/client.go
new file mode 100644
index 0000000..3c97a60
--- /dev/null
+++ b/dhcpv4/nclient4/client.go
@@ -0,0 +1,412 @@
+// Copyright 2018 the u-root Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+
+// Package nclient4 is a small, minimum-functionality client for DHCPv4.
+//
+// It only supports the 4-way DHCPv4 Discover-Offer-Request-Ack handshake as
+// well as the Request-Ack renewal process.
+package nclient4
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "log"
+ "net"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+)
+
+const (
+ defaultTimeout = 5 * time.Second
+ defaultRetries = 3
+ defaultBufferCap = 5
+ maxMessageSize = 1500
+
+ // ClientPort is the port that DHCP clients listen on.
+ ClientPort = 68
+
+ // ServerPort is the port that DHCP servers and relay agents listen on.
+ ServerPort = 67
+)
+
+var (
+ // DefaultServers is the address of all link-local DHCP servers and
+ // relay agents.
+ DefaultServers = &net.UDPAddr{
+ IP: net.IPv4bcast,
+ Port: ServerPort,
+ }
+)
+
+var (
+ // ErrNoResponse is returned when no response packet is received.
+ ErrNoResponse = errors.New("no matching response packet received")
+)
+
+// pendingCh is a channel associated with a pending TransactionID.
+type pendingCh struct {
+ // SendAndRead closes done to indicate that it wishes for no more
+ // messages for this particular XID.
+ done <-chan struct{}
+
+ // ch is used by the receive loop to distribute DHCP messages.
+ ch chan<- *dhcpv4.DHCPv4
+}
+
+// Client is an IPv4 DHCP client.
+type Client struct {
+ ifaceHWAddr net.HardwareAddr
+ conn net.PacketConn
+ timeout time.Duration
+ retry int
+
+ // bufferCap is the channel capacity for each TransactionID.
+ bufferCap int
+
+ // serverAddr is the UDP address to send all packets to.
+ //
+ // This may be an actual broadcast address, or a unicast address.
+ serverAddr *net.UDPAddr
+
+ // closed is an atomic bool set to 1 when done is closed.
+ closed uint32
+
+ // done is closed to unblock the receive loop.
+ done chan struct{}
+
+ // wg protects any spawned goroutines, namely the receiveLoop.
+ wg sync.WaitGroup
+
+ pendingMu sync.Mutex
+ // pending stores the distribution channels for each pending
+ // TransactionID. receiveLoop uses this map to determine which channel
+ // to send a new DHCP message to.
+ pending map[dhcpv4.TransactionID]*pendingCh
+}
+
+// New returns a client usable with an unconfigured interface.
+func New(ifaceName string, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
+ c := NewWithConn(nil, ifaceHWAddr, opts...)
+
+ // Do this after so that a caller can still use a WithConn to override
+ // the connection.
+ if c.conn == nil {
+ pc, err := NewRawUDPConn(ifaceName, ClientPort)
+ if err != nil {
+ return nil, err
+ }
+ c.conn = pc
+ }
+ return c, nil
+}
+
+// NewWithConn creates a new DHCP client that sends and receives packets on the
+// given interface.
+func NewWithConn(conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) *Client {
+ c := &Client{
+ ifaceHWAddr: ifaceHWAddr,
+ timeout: defaultTimeout,
+ retry: defaultRetries,
+ serverAddr: DefaultServers,
+ bufferCap: defaultBufferCap,
+ conn: conn,
+
+ done: make(chan struct{}),
+ pending: make(map[dhcpv4.TransactionID]*pendingCh),
+ }
+
+ for _, opt := range opts {
+ opt(c)
+ }
+
+ c.wg.Add(1)
+ go c.receiveLoop()
+ return c
+}
+
+// Close closes the underlying connection.
+func (c *Client) Close() error {
+ // Make sure not to close done twice.
+ if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ return nil
+ }
+
+ err := c.conn.Close()
+
+ // Closing c.done sets off a chain reaction:
+ //
+ // Any SendAndRead unblocks trying to receive more messages, which
+ // means rem() gets called.
+ //
+ // rem() should be unblocking receiveLoop if it is blocked.
+ //
+ // receiveLoop should then exit gracefully.
+ close(c.done)
+
+ // Wait for receiveLoop to stop.
+ c.wg.Wait()
+
+ return err
+}
+
+func isErrClosing(err error) bool {
+ // Unfortunately, the epoll-connection-closed error is internal to the
+ // net library.
+ return strings.Contains(err.Error(), "use of closed network connection")
+}
+
+func (c *Client) receiveLoop() {
+ defer c.wg.Done()
+ for {
+ // TODO: Clients can send a "max packet size" option in their
+ // packets, IIRC. Choose a reasonable size and set it.
+ b := make([]byte, maxMessageSize)
+ n, _, err := c.conn.ReadFrom(b)
+ if err != nil {
+ if !isErrClosing(err) {
+ log.Printf("error reading from UDP connection: %v", err)
+ }
+ return
+ }
+
+ msg, err := dhcpv4.FromBytes(b[:n])
+ if err != nil {
+ // Not a valid DHCP packet; keep listening.
+ continue
+ }
+
+ if msg.OpCode != dhcpv4.OpcodeBootReply {
+ // Not a response message.
+ continue
+ }
+
+ // This is a somewhat non-standard check, by the looks
+ // of RFC 2131. It should work as long as the DHCP
+ // server is spec-compliant for the HWAddr field.
+ if c.ifaceHWAddr != nil && !bytes.Equal(c.ifaceHWAddr, msg.ClientHWAddr) {
+ // Not for us.
+ continue
+ }
+
+ c.pendingMu.Lock()
+ p, ok := c.pending[msg.TransactionID]
+ if ok {
+ select {
+ case <-p.done:
+ close(p.ch)
+ delete(c.pending, msg.TransactionID)
+
+ // This send may block.
+ case p.ch <- msg:
+ }
+ }
+ c.pendingMu.Unlock()
+ }
+}
+
+// ClientOpt is a function that configures the Client.
+type ClientOpt func(*Client)
+
+// WithTimeout configures the retransmission timeout.
+//
+// Default is 5 seconds.
+func WithTimeout(d time.Duration) ClientOpt {
+ return func(c *Client) {
+ c.timeout = d
+ }
+}
+
+func withBufferCap(n int) ClientOpt {
+ return func(c *Client) {
+ c.bufferCap = n
+ }
+}
+
+// WithRetry configures the number of retransmissions to attempt.
+//
+// Default is 3.
+func WithRetry(r int) ClientOpt {
+ return func(c *Client) {
+ c.retry = r
+ }
+}
+
+// WithConn configures the packet connection to use.
+func WithConn(conn net.PacketConn) ClientOpt {
+ return func(c *Client) {
+ c.conn = conn
+ }
+}
+
+// WithServerAddr configures the address to send messages to.
+func WithServerAddr(n *net.UDPAddr) ClientOpt {
+ return func(c *Client) {
+ c.serverAddr = n
+ }
+}
+
+// Matcher matches DHCP packets.
+type Matcher func(*dhcpv4.DHCPv4) bool
+
+// IsMessageType returns a matcher that checks for the message type.
+//
+// If t is MessageTypeNone, all packets are matched.
+func IsMessageType(t dhcpv4.MessageType) Matcher {
+ return func(p *dhcpv4.DHCPv4) bool {
+ return p.MessageType() == t || t == dhcpv4.MessageTypeNone
+ }
+}
+
+// DiscoverOffer sends a DHCPDiscover message and returns the first valid offer
+// received.
+func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (*dhcpv4.DHCPv4, error) {
+ // RFC 2131, Section 4.4.1, Table 5 details what a DISCOVER packet should
+ // contain.
+ discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers,
+ dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(maxMessageSize)))...)
+ if err != nil {
+ return nil, err
+ }
+ return c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer))
+}
+
+// Request completes the 4-way Discover-Offer-Request-Ack handshake.
+//
+// Note that modifiers will be applied *both* to Discover and Request packets.
+func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer, ack *dhcpv4.DHCPv4, err error) {
+ offer, err = c.DiscoverOffer(ctx, modifiers...)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // TODO(chrisko): should this be unicast to the server?
+ req, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers,
+ dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(maxMessageSize)))...)
+ if err != nil {
+ return nil, nil, err
+ }
+ ack, err = c.SendAndRead(ctx, c.serverAddr, req, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+ return offer, ack, nil
+}
+
+// send sends p to destination and returns a response channel.
+//
+// Responses will be matched by transaction ID and ClientHWAddr.
+//
+// The returned lambda function must be called after all desired responses have
+// been received in order to return the Transaction ID to the usable pool.
+func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv4.DHCPv4, cancel func(), err error) {
+ c.pendingMu.Lock()
+ if _, ok := c.pending[msg.TransactionID]; ok {
+ c.pendingMu.Unlock()
+ return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID)
+ }
+
+ ch := make(chan *dhcpv4.DHCPv4, c.bufferCap)
+ done := make(chan struct{})
+ c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch}
+ c.pendingMu.Unlock()
+
+ cancel = func() {
+ // Why can't we just close ch here?
+ //
+ // Because receiveLoop may potentially be blocked trying to
+ // send on ch. We gotta unblock it first, and then we can take
+ // the lock and remove the XID from the pending transaction
+ // map.
+ close(done)
+
+ c.pendingMu.Lock()
+ if p, ok := c.pending[msg.TransactionID]; ok {
+ close(p.ch)
+ delete(c.pending, msg.TransactionID)
+ }
+ c.pendingMu.Unlock()
+ }
+
+ if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil {
+ cancel()
+ return nil, nil, fmt.Errorf("error writing packet to connection: %v", err)
+ }
+ return ch, cancel, nil
+}
+
+// This error should never be visible to users.
+// It is used only to increase the timeout in retryFn.
+var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded")
+
+// SendAndRead sends a packet p to a destination dest and waits for the first
+// response matching `match` as well as its Transaction ID and ClientHWAddr.
+//
+// If match is nil, the first packet matching the Transaction ID and
+// ClientHWAddr is returned.
+func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, p *dhcpv4.DHCPv4, match Matcher) (*dhcpv4.DHCPv4, error) {
+ var response *dhcpv4.DHCPv4
+ err := c.retryFn(func(timeout time.Duration) error {
+ ch, rem, err := c.send(dest, p)
+ if err != nil {
+ return err
+ }
+ defer rem()
+
+ for {
+ select {
+ case <-c.done:
+ return ErrNoResponse
+
+ case <-time.After(timeout):
+ return errDeadlineExceeded
+
+ case <-ctx.Done():
+ return ctx.Err()
+
+ case packet := <-ch:
+ if match == nil || match(packet) {
+ response = packet
+ return nil
+ }
+ }
+ }
+ })
+ if err == errDeadlineExceeded {
+ return nil, ErrNoResponse
+ }
+ if err != nil {
+ return nil, err
+ }
+ return response, nil
+}
+
+func (c *Client) retryFn(fn func(timeout time.Duration) error) error {
+ timeout := c.timeout
+
+ // Each retry takes the amount of timeout at worst.
+ for i := 0; i < c.retry || c.retry < 0; i++ {
+ switch err := fn(timeout); err {
+ case nil:
+ // Got it!
+ return nil
+
+ case errDeadlineExceeded:
+ // Double timeout, then retry.
+ timeout *= 2
+
+ default:
+ return err
+ }
+ }
+
+ return errDeadlineExceeded
+}
diff --git a/dhcpv4/nclient4/client_test.go b/dhcpv4/nclient4/client_test.go
new file mode 100644
index 0000000..d3ea68b
--- /dev/null
+++ b/dhcpv4/nclient4/client_test.go
@@ -0,0 +1,326 @@
+// Copyright 2018 the u-root Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+
+package nclient4
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/hugelgupf/socketpair"
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/insomniacslk/dhcp/dhcpv4/server4"
+)
+
+type handler struct {
+ mu sync.Mutex
+ received []*dhcpv4.DHCPv4
+
+ // Each received packet can have more than one response (in theory,
+ // from different servers sending different Advertise, for example).
+ responses [][]*dhcpv4.DHCPv4
+}
+
+func (h *handler) handle(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ h.received = append(h.received, m)
+
+ if len(h.responses) > 0 {
+ for _, resp := range h.responses[0] {
+ conn.WriteTo(resp.ToBytes(), peer)
+ }
+ h.responses = h.responses[1:]
+ }
+}
+
+func serveAndClient(ctx context.Context, responses [][]*dhcpv4.DHCPv4, opts ...ClientOpt) (*Client, net.PacketConn) {
+ // Fake PacketConn connection.
+ clientRawConn, serverRawConn, err := socketpair.PacketSocketPair()
+ if err != nil {
+ panic(err)
+ }
+
+ clientConn := NewBroadcastUDPConn(clientRawConn, &net.UDPAddr{IP: net.IPv4zero, Port: ClientPort})
+ serverConn := NewBroadcastUDPConn(serverRawConn, &net.UDPAddr{Port: ServerPort})
+
+ o := []ClientOpt{WithRetry(1), WithTimeout(2 * time.Second)}
+ o = append(o, opts...)
+ mc := NewWithConn(clientConn, net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...)
+
+ h := &handler{responses: responses}
+ s, err := server4.NewServer(nil, h.handle, server4.WithConn(serverConn))
+ if err != nil {
+ panic(err)
+ }
+ go s.Serve()
+
+ return mc, serverConn
+}
+
+func ComparePacket(got *dhcpv4.DHCPv4, want *dhcpv4.DHCPv4) error {
+ if got == nil && got == want {
+ return nil
+ }
+ if (want == nil || got == nil) && (got != want) {
+ return fmt.Errorf("packet got %v, want %v", got, want)
+ }
+ if bytes.Compare(got.ToBytes(), want.ToBytes()) != 0 {
+ return fmt.Errorf("packet got %v, want %v", got, want)
+ }
+ return nil
+}
+
+func pktsExpected(got []*dhcpv4.DHCPv4, want []*dhcpv4.DHCPv4) error {
+ if len(got) != len(want) {
+ return fmt.Errorf("got %d packets, want %d packets", len(got), len(want))
+ }
+
+ for i := range got {
+ if err := ComparePacket(got[i], want[i]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func newPacketWeirdHWAddr(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 {
+ p, err := dhcpv4.New()
+ if err != nil {
+ panic(fmt.Sprintf("newpacket: %v", err))
+ }
+ p.OpCode = op
+ p.TransactionID = xid
+ p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 1, 2, 3, 4, 5, 6}
+ return p
+}
+
+func newPacket(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 {
+ p, err := dhcpv4.New()
+ if err != nil {
+ panic(fmt.Sprintf("newpacket: %v", err))
+ }
+ p.OpCode = op
+ p.TransactionID = xid
+ p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
+ return p
+}
+
+func TestSendAndRead(t *testing.T) {
+ for _, tt := range []struct {
+ desc string
+ send *dhcpv4.DHCPv4
+ server []*dhcpv4.DHCPv4
+
+ // If want is nil, we assume server[0] contains what is wanted.
+ want *dhcpv4.DHCPv4
+ wantErr error
+ }{
+ {
+ desc: "two response packets",
+ send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ server: []*dhcpv4.DHCPv4{
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ },
+ want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "one response packet",
+ send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ server: []*dhcpv4.DHCPv4{
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ },
+ want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "one response packet, one invalid XID, one invalid opcode, one invalid hwaddr",
+ send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ server: []*dhcpv4.DHCPv4{
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x77, 0x33, 0x33, 0x33}),
+ newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ newPacketWeirdHWAddr(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ },
+ want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ },
+ {
+ desc: "discard wrong XID",
+ send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ server: []*dhcpv4.DHCPv4{
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0, 0, 0, 0}),
+ },
+ want: nil, // Explicitly empty.
+ wantErr: ErrNoResponse,
+ },
+ {
+ desc: "no response, timeout",
+ send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ wantErr: ErrNoResponse,
+ },
+ } {
+ t.Run(tt.desc, func(t *testing.T) {
+ // Both server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{tt.server},
+ // Use an unbuffered channel to make sure we
+ // have no deadlocks.
+ withBufferCap(0))
+ defer mc.Close()
+
+ rcvd, err := mc.SendAndRead(context.Background(), DefaultServers, tt.send, nil)
+ if err != tt.wantErr {
+ t.Error(err)
+ }
+
+ if err := ComparePacket(rcvd, tt.want); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+ })
+ }
+}
+
+func TestParallelSendAndRead(t *testing.T) {
+ pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
+
+ // Both the server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{},
+ WithTimeout(10*time.Second),
+ // Use an unbuffered channel to make sure nothing blocks.
+ withBufferCap(0))
+ defer mc.Close()
+
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
+ t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
+ }
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ time.Sleep(4 * time.Second)
+
+ if err := mc.Close(); err != nil {
+ t.Errorf("closing failed: %v", err)
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestReuseXID(t *testing.T) {
+ pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
+
+ // Both the server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{})
+ defer mc.Close()
+
+ if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
+ t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
+ }
+
+ if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
+ t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
+ }
+}
+
+func TestSimpleSendAndReadDiscardGarbage(t *testing.T) {
+ pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
+
+ responses := newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33})
+
+ // Both the server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, udpConn := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{{responses}})
+ defer mc.Close()
+
+ // Too short for valid DHCPv4 packet.
+ udpConn.WriteTo([]byte{0x01}, nil)
+ udpConn.WriteTo([]byte{0x01, 0x2}, nil)
+
+ rcvd, err := mc.SendAndRead(ctx, DefaultServers, pkt, nil)
+ if err != nil {
+ t.Errorf("SendAndRead(%v) = %v, want nil", pkt, err)
+ }
+
+ if err := ComparePacket(rcvd, responses); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+}
+
+func TestMultipleSendAndRead(t *testing.T) {
+ for _, tt := range []struct {
+ desc string
+ send []*dhcpv4.DHCPv4
+ server [][]*dhcpv4.DHCPv4
+ wantErr []error
+ }{
+ {
+ desc: "two requests, two responses",
+ send: []*dhcpv4.DHCPv4{
+ newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x44, 0x44, 0x44, 0x44}),
+ },
+ server: [][]*dhcpv4.DHCPv4{
+ []*dhcpv4.DHCPv4{ // Response for first packet.
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
+ },
+ []*dhcpv4.DHCPv4{ // Response for second packet.
+ newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x44, 0x44, 0x44, 0x44}),
+ },
+ },
+ wantErr: []error{
+ nil,
+ nil,
+ },
+ },
+ } {
+ // Both server and client only get 2 seconds.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ mc, _ := serveAndClient(ctx, tt.server)
+ defer mc.Close()
+
+ for i, send := range tt.send {
+ ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ rcvd, err := mc.SendAndRead(ctx, DefaultServers, send, nil)
+
+ if wantErr := tt.wantErr[i]; err != wantErr {
+ t.Errorf("SendAndReadOne(%v): got %v, want %v", send, err, wantErr)
+ }
+ if err := pktsExpected([]*dhcpv4.DHCPv4{rcvd}, tt.server[i]); err != nil {
+ t.Errorf("got unexpected packets: %v", err)
+ }
+ }
+ }
+}
diff --git a/dhcpv4/nclient4/conn_linux.go b/dhcpv4/nclient4/conn_linux.go
new file mode 100644
index 0000000..00c8a32
--- /dev/null
+++ b/dhcpv4/nclient4/conn_linux.go
@@ -0,0 +1,173 @@
+// Copyright 2018 the u-root Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+
+package nclient4
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "os"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/mdlayher/ethernet"
+ "github.com/mdlayher/raw"
+ "github.com/u-root/u-root/pkg/uio"
+ "golang.org/x/sys/unix"
+)
+
+var (
+ // BroadcastMac is the broadcast MAC address.
+ //
+ // Any UDP packet sent to this address is broadcast on the subnet.
+ BroadcastMac = net.HardwareAddr([]byte{255, 255, 255, 255, 255, 255})
+)
+
+// NewIPv4UDPConn returns a UDP connection bound to both the interface and port
+// given based on a IPv4 DGRAM socket. The UDP connection allows broadcasting.
+//
+// The interface must already be configured.
+func NewIPv4UDPConn(iface string, port int) (net.PacketConn, error) {
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
+ if err != nil {
+ return nil, fmt.Errorf("cannot get a UDP socket: %v", err)
+ }
+ f := os.NewFile(uintptr(fd), "")
+ // net.FilePacketConn dups the FD, so we have to close this in any case.
+ defer f.Close()
+
+ // Allow broadcasting.
+ if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1); err != nil {
+ return nil, fmt.Errorf("cannot set broadcasting on socket: %v", err)
+ }
+ // Allow reusing the addr to aid debugging.
+ if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
+ return nil, fmt.Errorf("cannot set reuseaddr on socket: %v", err)
+ }
+ if len(iface) != 0 {
+ // Bind directly to the interface.
+ if err := dhcpv4.BindToInterface(fd, iface); err != nil {
+ return nil, fmt.Errorf("cannot bind to interface %s: %v", iface, err)
+ }
+ }
+ // Bind to the port.
+ if err := unix.Bind(fd, &unix.SockaddrInet4{Port: port}); err != nil {
+ return nil, fmt.Errorf("cannot bind to port %d: %v", port, err)
+ }
+
+ return net.FilePacketConn(f)
+}
+
+// NewRawUDPConn returns a UDP connection bound to the interface and port
+// given based on a raw packet socket. All packets are broadcasted.
+//
+// The interface can be completely unconfigured.
+func NewRawUDPConn(iface string, port int) (net.PacketConn, error) {
+ ifc, err := net.InterfaceByName(iface)
+ if err != nil {
+ return nil, err
+ }
+ rawConn, err := raw.ListenPacket(ifc, uint16(ethernet.EtherTypeIPv4), &raw.Config{LinuxSockDGRAM: true})
+ if err != nil {
+ return nil, err
+ }
+ return NewBroadcastUDPConn(rawConn, &net.UDPAddr{Port: port}), nil
+}
+
+// BroadcastRawUDPConn uses a raw socket to send UDP packets to the broadcast
+// MAC address.
+type BroadcastRawUDPConn struct {
+ // PacketConn is a raw DGRAM socket.
+ net.PacketConn
+
+ // boundAddr is the address this RawUDPConn is "bound" to.
+ //
+ // Calls to ReadFrom will only return packets destined to this address.
+ boundAddr *net.UDPAddr
+}
+
+// NewBroadcastUDPConn returns a PacketConn that marshals and unmarshals UDP
+// packets, sending them to the broadcast MAC at on rawPacketConn.
+//
+// Calls to ReadFrom will only return packets destined to boundAddr.
+func NewBroadcastUDPConn(rawPacketConn net.PacketConn, boundAddr *net.UDPAddr) net.PacketConn {
+ return &BroadcastRawUDPConn{
+ PacketConn: rawPacketConn,
+ boundAddr: boundAddr,
+ }
+}
+
+func udpMatch(addr *net.UDPAddr, bound *net.UDPAddr) bool {
+ if bound == nil {
+ return true
+ }
+ if bound.IP != nil && !bound.IP.Equal(addr.IP) {
+ return false
+ }
+ return bound.Port == addr.Port
+}
+
+// ReadFrom implements net.PacketConn.ReadFrom.
+//
+// ReadFrom reads raw IP packets and will try to match them against
+// upc.boundAddr. Any matching packets are returned via the given buffer.
+func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
+ ipLen := IPv4MaximumHeaderSize
+ udpLen := UDPMinimumSize
+
+ for {
+ pkt := make([]byte, ipLen+udpLen+len(b))
+ n, _, err := upc.PacketConn.ReadFrom(pkt)
+ if err != nil {
+ return 0, nil, err
+ }
+ if n == 0 {
+ return 0, nil, io.EOF
+ }
+ pkt = pkt[:n]
+ buf := uio.NewBigEndianBuffer(pkt)
+
+ // To read the header length, access data directly.
+ ipHdr := IPv4(buf.Data())
+ ipHdr = IPv4(buf.Consume(int(ipHdr.HeaderLength())))
+
+ if ipHdr.TransportProtocol() != UDPProtocolNumber {
+ continue
+ }
+ udpHdr := UDP(buf.Consume(udpLen))
+
+ addr := &net.UDPAddr{
+ IP: net.IP(ipHdr.DestinationAddress()),
+ Port: int(udpHdr.DestinationPort()),
+ }
+ if !udpMatch(addr, upc.boundAddr) {
+ continue
+ }
+ srcAddr := &net.UDPAddr{
+ IP: net.IP(ipHdr.SourceAddress()),
+ Port: int(udpHdr.SourcePort()),
+ }
+ return copy(b, buf.ReadAll()), srcAddr, nil
+ }
+}
+
+// WriteTo implements net.PacketConn.WriteTo and broadcasts all packets at the
+// raw socket level.
+//
+// WriteTo wraps the given packet in the appropriate UDP and IP header before
+// sending it on the packet conn.
+func (upc *BroadcastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+ udpAddr, ok := addr.(*net.UDPAddr)
+ if !ok {
+ return 0, fmt.Errorf("must supply UDPAddr")
+ }
+
+ // Using the boundAddr is not quite right here, but it works.
+ packet := udp4pkt(b, udpAddr, upc.boundAddr)
+
+ // Broadcasting is not always right, but hell, what the ARP do I know.
+ return upc.PacketConn.WriteTo(packet, &raw.Addr{HardwareAddr: BroadcastMac})
+}
diff --git a/dhcpv4/nclient4/ipv4.go b/dhcpv4/nclient4/ipv4.go
new file mode 100644
index 0000000..81ba837
--- /dev/null
+++ b/dhcpv4/nclient4/ipv4.go
@@ -0,0 +1,376 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// This file contains code taken from gVisor.
+
+// +build go1.12
+
+package nclient4
+
+import (
+ "encoding/binary"
+ "net"
+
+ "github.com/u-root/u-root/pkg/uio"
+)
+
+const (
+ versIHL = 0
+ tos = 1
+ totalLen = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksum = 10
+ srcAddr = 12
+ dstAddr = 16
+)
+
+// TransportProtocolNumber is the number of a transport protocol.
+type TransportProtocolNumber uint32
+
+// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv4Fields struct {
+ // IHL is the "internet header length" field of an IPv4 packet.
+ IHL uint8
+
+ // TOS is the "type of service" field of an IPv4 packet.
+ TOS uint8
+
+ // TotalLength is the "total length" field of an IPv4 packet.
+ TotalLength uint16
+
+ // ID is the "identification" field of an IPv4 packet.
+ ID uint16
+
+ // Flags is the "flags" field of an IPv4 packet.
+ Flags uint8
+
+ // FragmentOffset is the "fragment offset" field of an IPv4 packet.
+ FragmentOffset uint16
+
+ // TTL is the "time to live" field of an IPv4 packet.
+ TTL uint8
+
+ // Protocol is the "protocol" field of an IPv4 packet.
+ Protocol uint8
+
+ // Checksum is the "checksum" field of an IPv4 packet.
+ Checksum uint16
+
+ // SrcAddr is the "source ip address" of an IPv4 packet.
+ SrcAddr net.IP
+
+ // DstAddr is the "destination ip address" of an IPv4 packet.
+ DstAddr net.IP
+}
+
+// IPv4 represents an ipv4 header stored in a byte array.
+// Most of the methods of IPv4 access to the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv4 before using other methods.
+type IPv4 []byte
+
+const (
+ // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+ IPv4MinimumSize = 20
+
+ // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
+ // that there are only 4 bits to represents the header length in 32-bit
+ // units, the header cannot exceed 15*4 = 60 bytes.
+ IPv4MaximumHeaderSize = 60
+
+ // IPv4AddressSize is the size, in bytes, of an IPv4 address.
+ IPv4AddressSize = 4
+
+ // IPv4Version is the version of the ipv4 protocol.
+ IPv4Version = 4
+)
+
+var (
+ // IPv4Broadcast is the broadcast address of the IPv4 procotol.
+ IPv4Broadcast = net.IP{0xff, 0xff, 0xff, 0xff}
+
+ // IPv4Any is the non-routable IPv4 "any" meta address.
+ IPv4Any = net.IP{0, 0, 0, 0}
+)
+
+// Flags that may be set in an IPv4 packet.
+const (
+ IPv4FlagMoreFragments = 1 << iota
+ IPv4FlagDontFragment
+)
+
+// HeaderLength returns the value of the "header length" field of the ipv4
+// header.
+func (b IPv4) HeaderLength() uint8 {
+ return (b[versIHL] & 0xf) * 4
+}
+
+// Protocol returns the value of the protocol field of the ipv4 header.
+func (b IPv4) Protocol() uint8 {
+ return b[protocol]
+}
+
+// SourceAddress returns the "source address" field of the ipv4 header.
+func (b IPv4) SourceAddress() net.IP {
+ return net.IP(b[srcAddr : srcAddr+IPv4AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv4
+// header.
+func (b IPv4) DestinationAddress() net.IP {
+ return net.IP(b[dstAddr : dstAddr+IPv4AddressSize])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv4) TransportProtocol() TransportProtocolNumber {
+ return TransportProtocolNumber(b.Protocol())
+}
+
+// Payload implements Network.Payload.
+func (b IPv4) Payload() []byte {
+ return b[b.HeaderLength():][:b.PayloadLength()]
+}
+
+// PayloadLength returns the length of the payload portion of the ipv4 packet.
+func (b IPv4) PayloadLength() uint16 {
+ return b.TotalLength() - uint16(b.HeaderLength())
+}
+
+// TotalLength returns the "total length" field of the ipv4 header.
+func (b IPv4) TotalLength() uint16 {
+ return binary.BigEndian.Uint16(b[totalLen:])
+}
+
+// SetTotalLength sets the "total length" field of the ipv4 header.
+func (b IPv4) SetTotalLength(totalLength uint16) {
+ binary.BigEndian.PutUint16(b[totalLen:], totalLength)
+}
+
+// SetChecksum sets the checksum field of the ipv4 header.
+func (b IPv4) SetChecksum(v uint16) {
+ binary.BigEndian.PutUint16(b[checksum:], v)
+}
+
+// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
+// ipv4 header.
+func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
+ v := (uint16(flags) << 13) | (offset >> 3)
+ binary.BigEndian.PutUint16(b[flagsFO:], v)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv4 header.
+func (b IPv4) SetSourceAddress(addr net.IP) {
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.To4())
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv4
+// header.
+func (b IPv4) SetDestinationAddress(addr net.IP) {
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.To4())
+}
+
+// CalculateChecksum calculates the checksum of the ipv4 header.
+func (b IPv4) CalculateChecksum() uint16 {
+ return Checksum(b[:b.HeaderLength()], 0)
+}
+
+// Encode encodes all the fields of the ipv4 header.
+func (b IPv4) Encode(i *IPv4Fields) {
+ b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b[tos] = i.TOS
+ b.SetTotalLength(i.TotalLength)
+ binary.BigEndian.PutUint16(b[id:], i.ID)
+ b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
+ b[ttl] = i.TTL
+ b[protocol] = i.Protocol
+ b.SetChecksum(i.Checksum)
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
+}
+
+const (
+ udpSrcPort = 0
+ udpDstPort = 2
+ udpLength = 4
+ udpChecksum = 6
+)
+
+// UDPFields contains the fields of a UDP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type UDPFields struct {
+ // SrcPort is the "source port" field of a UDP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a UDP packet.
+ DstPort uint16
+
+ // Length is the "length" field of a UDP packet.
+ Length uint16
+
+ // Checksum is the "checksum" field of a UDP packet.
+ Checksum uint16
+}
+
+// UDP represents a UDP header stored in a byte array.
+type UDP []byte
+
+const (
+ // UDPMinimumSize is the minimum size of a valid UDP packet.
+ UDPMinimumSize = 8
+
+ // UDPProtocolNumber is UDP's transport protocol number.
+ UDPProtocolNumber TransportProtocolNumber = 17
+)
+
+// SourcePort returns the "source port" field of the udp header.
+func (b UDP) SourcePort() uint16 {
+ return binary.BigEndian.Uint16(b[udpSrcPort:])
+}
+
+// DestinationPort returns the "destination port" field of the udp header.
+func (b UDP) DestinationPort() uint16 {
+ return binary.BigEndian.Uint16(b[udpDstPort:])
+}
+
+// Length returns the "length" field of the udp header.
+func (b UDP) Length() uint16 {
+ return binary.BigEndian.Uint16(b[udpLength:])
+}
+
+// SetSourcePort sets the "source port" field of the udp header.
+func (b UDP) SetSourcePort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the udp header.
+func (b UDP) SetDestinationPort(port uint16) {
+ binary.BigEndian.PutUint16(b[udpDstPort:], port)
+}
+
+// SetChecksum sets the "checksum" field of the udp header.
+func (b UDP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
+}
+
+// Payload returns the data contained in the UDP datagram.
+func (b UDP) Payload() []byte {
+ return b[UDPMinimumSize:]
+}
+
+// Checksum returns the "checksum" field of the udp header.
+func (b UDP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[udpChecksum:])
+}
+
+// CalculateChecksum calculates the checksum of the udp packet, given the total
+// length of the packet and the checksum of the network-layer pseudo-header
+// (excluding the total length) and the checksum of the payload.
+func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
+ // Add the length portion of the checksum to the pseudo-checksum.
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ checksum := Checksum(tmp, partialChecksum)
+
+ // Calculate the rest of the checksum.
+ return Checksum(b[:UDPMinimumSize], checksum)
+}
+
+// Encode encodes all the fields of the udp header.
+func (b UDP) Encode(u *UDPFields) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
+ binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
+ binary.BigEndian.PutUint16(b[udpLength:], u.Length)
+ binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
+}
+
+func calculateChecksum(buf []byte, initial uint32) uint16 {
+ v := initial
+
+ l := len(buf)
+ if l&1 != 0 {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+
+ for i := 0; i < l; i += 2 {
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16))
+}
+
+// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// given byte array.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func Checksum(buf []byte, initial uint16) uint16 {
+ return calculateChecksum(buf, uint32(initial))
+}
+
+// ChecksumCombine combines the two uint16 to form their checksum. This is done
+// by adding them and the carry.
+//
+// Note that checksum a must have been computed on an even number of bytes.
+func ChecksumCombine(a, b uint16) uint16 {
+ v := uint32(a) + uint32(b)
+ return uint16(v + v>>16)
+}
+
+// PseudoHeaderChecksum calculates the pseudo-header checksum for the
+// given destination protocol and network address, ignoring the length
+// field. Pseudo-headers are needed by transport layers when calculating
+// their own checksum.
+func PseudoHeaderChecksum(protocol TransportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 {
+ xsum := Checksum([]byte(srcAddr), 0)
+ xsum = Checksum([]byte(dstAddr), xsum)
+ return Checksum([]byte{0, uint8(protocol)}, xsum)
+}
+
+func udp4pkt(packet []byte, dest *net.UDPAddr, src *net.UDPAddr) []byte {
+ ipLen := IPv4MinimumSize
+ udpLen := UDPMinimumSize
+
+ h := make([]byte, 0, ipLen+udpLen+len(packet))
+ hdr := uio.NewBigEndianBuffer(h)
+
+ ipv4fields := &IPv4Fields{
+ IHL: IPv4MinimumSize,
+ TotalLength: uint16(ipLen + udpLen + len(packet)),
+ TTL: 30,
+ Protocol: uint8(UDPProtocolNumber),
+ SrcAddr: src.IP.To4(),
+ DstAddr: dest.IP.To4(),
+ }
+ ipv4hdr := IPv4(hdr.WriteN(ipLen))
+ ipv4hdr.Encode(ipv4fields)
+ ipv4hdr.SetChecksum(^ipv4hdr.CalculateChecksum())
+
+ udphdr := UDP(hdr.WriteN(udpLen))
+ udphdr.Encode(&UDPFields{
+ SrcPort: uint16(src.Port),
+ DstPort: uint16(dest.Port),
+ Length: uint16(udpLen + len(packet)),
+ })
+
+ xsum := Checksum(packet, PseudoHeaderChecksum(
+ ipv4hdr.TransportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr))
+ udphdr.SetChecksum(^udphdr.CalculateChecksum(xsum, udphdr.Length()))
+
+ hdr.WriteBytes(packet)
+ return hdr.Data()
+}
diff --git a/dhcpv4/server4/server.go b/dhcpv4/server4/server.go
index 5ef4479..1ccd5f4 100644
--- a/dhcpv4/server4/server.go
+++ b/dhcpv4/server4/server.go
@@ -1,11 +1,9 @@
package server4
import (
- "fmt"
"log"
"net"
"sync"
- "time"
"github.com/insomniacslk/dhcp/dhcpv4"
)
@@ -36,7 +34,7 @@ import (
"github.com/insomniacslk/dhcp/dhcpv4"
)
-func handler(conn net.PacketConn, peer net.Addr, m dhcpv4.DHCPv4) {
+func handler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
// this function will just print the received DHCPv4 message, without replying
log.Print(m.Summary())
}
@@ -46,12 +44,14 @@ func main() {
IP: net.ParseIP("127.0.0.1"),
Port: 67,
}
- server := dhcpv4.NewServer(laddr, handler)
-
- defer server.Close()
- if err := server.ActivateAndServe(); err != nil {
- log.Panic(err)
+ server, err := dhcpv4.NewServer(laddr, handler)
+ if err != nil {
+ log.Fatal(err)
}
+
+ // This never returns. If you want to do other stuff, dump it into a
+ // goroutine.
+ server.Serve()
}
*/
@@ -66,98 +66,61 @@ type Server struct {
connMutex sync.Mutex
shouldStop chan bool
Handler Handler
- localAddr net.UDPAddr
-}
-
-// LocalAddr returns the local address of the listening socket, or nil if not
-// listening
-func (s *Server) LocalAddr() net.Addr {
- s.connMutex.Lock()
- defer s.connMutex.Unlock()
- if s.conn == nil {
- return nil
- }
- return s.conn.LocalAddr()
}
-// ActivateAndServe starts the DHCPv4 server. The listener will run in
-// background, and can be interrupted with `Server.Close`.
-func (s *Server) ActivateAndServe() error {
- s.connMutex.Lock()
- if s.conn != nil {
- // this may panic if s.conn is closed but not reset properly. For that
- // you should use `Server.Close`.
- s.Close()
- }
- conn, err := net.ListenUDP("udp4", &s.localAddr)
- if err != nil {
- s.connMutex.Unlock()
- return err
- }
- s.conn = conn
- s.connMutex.Unlock()
- var (
- pc *net.UDPConn
- ok bool
- )
- if pc, ok = s.conn.(*net.UDPConn); !ok {
- return fmt.Errorf("error: not an UDPConn")
- }
- if pc == nil {
- return fmt.Errorf("ActivateAndServe: invalid nil PacketConn")
- }
- log.Printf("Server listening on %s", pc.LocalAddr())
+// Serve serves requests.
+func (s *Server) Serve() {
+ log.Printf("Server listening on %s", s.conn.LocalAddr())
log.Print("Ready to handle requests")
for {
- select {
- case <-s.shouldStop:
- break
- case <-time.After(time.Millisecond):
- }
- pc.SetReadDeadline(time.Now().Add(time.Second))
rbuf := make([]byte, 4096) // FIXME this is bad
- n, peer, err := pc.ReadFrom(rbuf)
+ n, peer, err := s.conn.ReadFrom(rbuf)
if err != nil {
- switch err.(type) {
- case net.Error:
- if !err.(net.Error).Timeout() {
- return err
- }
- // if timeout, silently skip and continue
- default:
- // complain and continue
- log.Printf("Error reading from packet conn: %v", err)
- }
- continue
+ log.Printf("Error reading from packet conn: %v", err)
+ return
}
log.Printf("Handling request from %v", peer)
+
m, err := dhcpv4.FromBytes(rbuf[:n])
if err != nil {
log.Printf("Error parsing DHCPv4 request: %v", err)
continue
}
- go s.Handler(pc, peer, m)
+ go s.Handler(s.conn, peer, m)
}
}
// Close sends a termination request to the server, and closes the UDP listener
func (s *Server) Close() error {
- s.shouldStop <- true
- s.connMutex.Lock()
- defer s.connMutex.Unlock()
- if s.conn != nil {
- ret := s.conn.Close()
- s.conn = nil
- return ret
+ return s.conn.Close()
+}
+
+// ServerOpt adds optional configuration to a server.
+type ServerOpt func(s *Server)
+
+// WithConn configures the server with the given connection.
+func WithConn(c net.PacketConn) ServerOpt {
+ return func(s *Server) {
+ s.conn = c
}
- return nil
}
// NewServer initializes and returns a new Server object
-func NewServer(addr net.UDPAddr, handler Handler) *Server {
- return &Server{
- localAddr: addr,
+func NewServer(addr *net.UDPAddr, handler Handler, opt ...ServerOpt) (*Server, error) {
+ s := &Server{
Handler: handler,
shouldStop: make(chan bool, 1),
}
+
+ for _, o := range opt {
+ o(s)
+ }
+ if s.conn == nil {
+ var err error
+ s.conn, err = net.ListenUDP("udp4", addr)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return s, nil
}
diff --git a/dhcpv4/server4/server_test.go b/dhcpv4/server4/server_test.go
index f6ce18e..a6895d2 100644
--- a/dhcpv4/server4/server_test.go
+++ b/dhcpv4/server4/server_test.go
@@ -1,8 +1,9 @@
-// +build integration
+// +build go1.12
package server4
import (
+ "context"
"log"
"math/rand"
"net"
@@ -10,7 +11,7 @@ import (
"time"
"github.com/insomniacslk/dhcp/dhcpv4"
- "github.com/insomniacslk/dhcp/dhcpv4/client4"
+ "github.com/insomniacslk/dhcp/dhcpv4/nclient4"
"github.com/insomniacslk/dhcp/interfaces"
"github.com/stretchr/testify/require"
)
@@ -61,68 +62,58 @@ func DORAHandler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
// utility function to set up a client and a server instance and run it in
// background. The caller needs to call Server.Close() once finished.
-func setUpClientAndServer(handler Handler) (*client4.Client, *Server) {
+func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler) (*nclient4.Client, *Server) {
// strong assumption, I know
loAddr := net.ParseIP("127.0.0.1")
- laddr := net.UDPAddr{
+ saddr := net.UDPAddr{
IP: loAddr,
Port: randPort(),
}
- s := NewServer(laddr, handler)
- go s.ActivateAndServe()
-
- c := client4.NewClient()
- // FIXME this doesn't deal well with raw sockets, the actual 0 will be used
- // in the UDP header as source port
- c.LocalAddr = &net.UDPAddr{IP: loAddr, Port: randPort()}
- for {
- if s.LocalAddr() != nil {
- break
- }
- time.Sleep(10 * time.Millisecond)
- log.Printf("Waiting for server to run...")
+ caddr := net.UDPAddr{
+ IP: loAddr,
+ Port: randPort(),
}
- c.RemoteAddr = s.LocalAddr()
- log.Printf("Client.RemoteAddr: %s", c.RemoteAddr)
-
- return c, s
-}
-
-func TestNewServer(t *testing.T) {
- laddr := net.UDPAddr{
- IP: net.ParseIP("127.0.0.1"),
- Port: 0,
+ s, err := NewServer(&saddr, handler)
+ if err != nil {
+ t.Fatal(err)
}
- s := NewServer(laddr, DORAHandler)
- defer s.Close()
+ go s.Serve()
- require.NotNil(t, s)
- require.Nil(t, s.conn)
- require.Equal(t, laddr, s.localAddr)
- require.NotNil(t, s.Handler)
+ clientConn, err := nclient4.NewIPv4UDPConn("", caddr.Port)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c := nclient4.NewWithConn(clientConn, iface.HardwareAddr, nclient4.WithServerAddr(&saddr))
+ return c, s
}
-func TestServerActivateAndServe(t *testing.T) {
- c, s := setUpClientAndServer(DORAHandler)
- defer s.Close()
-
+func TestServer(t *testing.T) {
ifaces, err := interfaces.GetLoopbackInterfaces()
require.NoError(t, err)
require.NotEqual(t, 0, len(ifaces))
+ // lo has a HardwareAddr of "nil". The client will drop all packets
+ // that don't match the HWAddr of the client interface.
+ hwaddr := net.HardwareAddr{1, 2, 3, 4, 5, 6}
+ ifaces[0].HardwareAddr = hwaddr
+
+ c, s := setUpClientAndServer(t, ifaces[0], DORAHandler)
+ defer func() {
+ require.Nil(t, s.Close())
+ }()
+
xid := dhcpv4.TransactionID{0xaa, 0xbb, 0xcc, 0xdd}
- hwaddr := net.HardwareAddr{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
modifiers := []dhcpv4.Modifier{
dhcpv4.WithTransactionID(xid),
- dhcpv4.WithHwAddr(hwaddr),
+ dhcpv4.WithHwAddr(ifaces[0].HardwareAddr),
}
- conv, err := c.Exchange(ifaces[0].Name, modifiers...)
+ offer, ack, err := c.Request(context.Background(), modifiers...)
require.NoError(t, err)
- require.Equal(t, 4, len(conv))
- for _, p := range conv {
+ require.NotNil(t, offer, ack)
+ for _, p := range []*dhcpv4.DHCPv4{offer, ack} {
require.Equal(t, xid, p.TransactionID)
- require.Equal(t, hwaddr, p.ClientHWAddr)
+ require.Equal(t, ifaces[0].HardwareAddr, p.ClientHWAddr)
}
}