summaryrefslogtreecommitdiffhomepage
path: root/pkg/dhcp/client.go
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@gmail.com>2018-08-25 06:16:34 -0700
committerShentubot <shentubot@google.com>2018-08-25 06:17:32 -0700
commitb17e80ef5a44e773e9032e7dbcb7438ff851ab7c (patch)
tree5c40dd5e44d70c51c2089ec10b51bd480fb8be50 /pkg/dhcp/client.go
parent106de2182d34197d76fb68863cd4a102ebac2dbb (diff)
Upstreaming DHCP changes from Fuchsia
PiperOrigin-RevId: 210221388 Change-Id: Ic82d592b8c4778855fa55ba913f6b9a10b2d511f
Diffstat (limited to 'pkg/dhcp/client.go')
-rw-r--r--pkg/dhcp/client.go285
1 files changed, 174 insertions, 111 deletions
diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go
index 8b5fc0452..909040e79 100644
--- a/pkg/dhcp/client.go
+++ b/pkg/dhcp/client.go
@@ -18,7 +18,6 @@ import (
"bytes"
"context"
"fmt"
- "log"
"sync"
"time"
@@ -32,9 +31,10 @@ import (
// Client is a DHCP client.
type Client struct {
- stack *stack.Stack
- nicid tcpip.NICID
- linkAddr tcpip.LinkAddress
+ stack *stack.Stack
+ nicid tcpip.NICID
+ linkAddr tcpip.LinkAddress
+ acquiredFunc func(old, new tcpip.Address, cfg Config)
mu sync.Mutex
addr tcpip.Address
@@ -46,29 +46,57 @@ type Client struct {
// NewClient creates a DHCP client.
//
// TODO: add s.LinkAddr(nicid) to *stack.Stack.
-func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client {
+func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress, acquiredFunc func(old, new tcpip.Address, cfg Config)) *Client {
return &Client{
- stack: s,
- nicid: nicid,
- linkAddr: linkAddr,
+ stack: s,
+ nicid: nicid,
+ linkAddr: linkAddr,
+ acquiredFunc: acquiredFunc,
}
}
-// Start starts the DHCP client.
+// Run starts the DHCP client.
// It will periodically search for an IP address using the Request method.
-func (c *Client) Start() {
- go func() {
- for {
- log.Print("DHCP request")
- ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- err := c.Request(ctx, "")
- cancel()
- if err == nil {
- break
- }
+func (c *Client) Run(ctx context.Context) {
+ go c.run(ctx)
+}
+
+func (c *Client) run(ctx context.Context) {
+ defer func() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.addr != "" {
+ c.stack.RemoveAddress(c.nicid, c.addr)
}
- log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength)
}()
+
+ var renewAddr tcpip.Address
+ for {
+ reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
+ cfg, err := c.Request(reqCtx, renewAddr)
+ cancel()
+ if err != nil {
+ select {
+ case <-time.After(1 * time.Second):
+ // loop and try again
+ case <-ctx.Done():
+ return
+ }
+ }
+
+ c.mu.Lock()
+ renewAddr = c.addr
+ c.mu.Unlock()
+
+ timer := time.NewTimer(cfg.LeaseLength)
+ select {
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ case <-timer.C:
+ // loop and make a renewal request
+ }
+ }
}
// Address reports the IP address acquired by the DHCP client.
@@ -85,56 +113,53 @@ func (c *Client) Config() Config {
return c.cfg
}
-// Shutdown relinquishes any lease and ends any outstanding renewal timers.
-func (c *Client) Shutdown() {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.addr != "" {
- c.stack.RemoveAddress(c.nicid, c.addr)
- }
- if c.cancelRenew != nil {
- c.cancelRenew()
- }
-}
-
// Request executes a DHCP request session.
//
// On success, it adds a new address to this client's TCPIP stack.
// If the server sets a lease limit a timer is set to automatically
// renew it.
-func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error {
+func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg Config, reterr error) {
+ if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil && err != tcpip.ErrDuplicateAddress {
+ return Config{}, fmt.Errorf("dhcp: %v", err)
+ }
+ if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil && err != tcpip.ErrDuplicateAddress {
+ return Config{}, fmt.Errorf("dhcp: %v", err)
+ }
+ defer c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
+ defer c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")
+
var wq waiter.Queue
ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- return fmt.Errorf("dhcp: outbound endpoint: %v", err)
+ return Config{}, fmt.Errorf("dhcp: outbound endpoint: %v", err)
}
- err = ep.Bind(tcpip.FullAddress{
- Addr: "\x00\x00\x00\x00",
- Port: clientPort,
- }, nil)
defer ep.Close()
- if err != nil {
- return fmt.Errorf("dhcp: connect failed: %v", err)
+ if err := ep.Bind(tcpip.FullAddress{
+ Addr: "\x00\x00\x00\x00",
+ Port: ClientPort,
+ NIC: c.nicid,
+ }, nil); err != nil {
+ return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- return fmt.Errorf("dhcp: inbound endpoint: %v", err)
+ return Config{}, fmt.Errorf("dhcp: inbound endpoint: %v", err)
}
- err = epin.Bind(tcpip.FullAddress{
- Addr: "\xff\xff\xff\xff",
- Port: clientPort,
- }, nil)
defer epin.Close()
- if err != nil {
- return fmt.Errorf("dhcp: connect failed: %v", err)
+ if err := epin.Bind(tcpip.FullAddress{
+ Addr: "\xff\xff\xff\xff",
+ Port: ClientPort,
+ NIC: c.nicid,
+ }, nil); err != nil {
+ return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
var xid [4]byte
rand.Read(xid[:])
// DHCPDISCOVERY
- options := options{
+ discOpts := options{
{optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
{optParamReq, []byte{
1, // request subnet mask
@@ -144,25 +169,34 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
}},
}
if requestedAddr != "" {
- options = append(options, option{optReqIPAddr, []byte(requestedAddr)})
+ discOpts = append(discOpts, option{optReqIPAddr, []byte(requestedAddr)})
}
- h := make(header, headerBaseSize+options.len())
+ var clientID []byte
+ if len(c.linkAddr) == 6 {
+ clientID = append(
+ []byte{1}, // RFC 1700: Hardware Type [Ethernet = 1]
+ c.linkAddr...,
+ )
+ discOpts = append(discOpts, option{optClientID, clientID})
+ }
+ h := make(header, headerBaseSize+discOpts.len()+1)
h.init()
h.setOp(opRequest)
copy(h.xidbytes(), xid[:])
h.setBroadcast()
copy(h.chaddr(), c.linkAddr)
- h.setOptions(options)
+ h.setOptions(discOpts)
serverAddr := &tcpip.FullAddress{
Addr: "\xff\xff\xff\xff",
- Port: serverPort,
+ Port: ServerPort,
+ NIC: c.nicid,
}
wopts := tcpip.WriteOptions{
To: serverAddr,
}
if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
- return fmt.Errorf("dhcp discovery write: %v", err)
+ return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -170,6 +204,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
defer wq.EventUnregister(&we)
// DHCPOFFER
+ var opts options
for {
var addr tcpip.FullAddress
v, _, err := epin.Read(&addr)
@@ -178,49 +213,84 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
case <-ch:
continue
case <-ctx.Done():
- return fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
+ return Config{}, fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
}
}
h = header(v)
- if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
- break
+ var valid bool
+ var e error
+ opts, valid, e = loadDHCPReply(h, dhcpOFFER, xid[:])
+ if !valid {
+ if e != nil {
+ // TODO: handle all the errors?
+ // TODO: report malformed server responses
+ }
+ continue
}
- }
- if _, err := h.options(); err != nil {
- return fmt.Errorf("dhcp offer: %v", err)
+ break
}
var ack bool
- var cfg Config
+ if err := cfg.decode(opts); err != nil {
+ return Config{}, fmt.Errorf("dhcp offer: %v", err)
+ }
// DHCPREQUEST
addr := tcpip.Address(h.yiaddr())
if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
if err != tcpip.ErrDuplicateAddress {
- return fmt.Errorf("adding address: %v", err)
+ return Config{}, fmt.Errorf("adding address: %v", err)
}
}
defer func() {
- if ack {
- c.mu.Lock()
- c.addr = addr
- c.cfg = cfg
- c.mu.Unlock()
- } else {
+ if !ack || reterr != nil {
c.stack.RemoveAddress(c.nicid, addr)
+ addr = ""
+ cfg = Config{Error: reterr}
+ }
+
+ c.mu.Lock()
+ oldAddr := c.addr
+ c.addr = addr
+ c.cfg = cfg
+ c.mu.Unlock()
+
+ // Clean up broadcast addresses before calling acquiredFunc
+ // so nothing else uses them by mistake.
+ //
+ // (The deferred RemoveAddress calls above silently error.)
+ c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
+ c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")
+
+ if c.acquiredFunc != nil {
+ c.acquiredFunc(oldAddr, addr, cfg)
+ }
+ if requestedAddr != "" && requestedAddr != addr {
+ c.stack.RemoveAddress(c.nicid, requestedAddr)
}
}()
+ h.init()
h.setOp(opRequest)
for i, b := 0, h.yiaddr(); i < len(b); i++ {
b[i] = 0
}
- h.setOptions([]option{
+ for i, b := 0, h.siaddr(); i < len(b); i++ {
+ b[i] = 0
+ }
+ for i, b := 0, h.giaddr(); i < len(b); i++ {
+ b[i] = 0
+ }
+ reqOpts := []option{
{optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
{optReqIPAddr, []byte(addr)},
- {optDHCPServer, h.siaddr()},
- })
+ {optDHCPServer, []byte(cfg.ServerAddress)},
+ }
+ if len(clientID) != 0 {
+ reqOpts = append(reqOpts, option{optClientID, clientID})
+ }
+ h.setOptions(reqOpts)
if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
- return fmt.Errorf("dhcp discovery write: %v", err)
+ return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
// DHCPACK
@@ -232,53 +302,46 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
case <-ch:
continue
case <-ctx.Done():
- return fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
+ return Config{}, fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
}
}
h = header(v)
- if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
- break
+ var valid bool
+ var e error
+ opts, valid, e = loadDHCPReply(h, dhcpACK, xid[:])
+ if !valid {
+ if e != nil {
+ // TODO: handle all the errors?
+ // TODO: report malformed server responses
+ }
+ if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid {
+ if msg := opts.message(); msg != "" {
+ return Config{}, fmt.Errorf("dhcp: NAK %q", msg)
+ }
+ return Config{}, fmt.Errorf("dhcp: NAK with no message")
+ }
+ continue
}
+ break
}
- opts, e := h.options()
- if e != nil {
- return fmt.Errorf("dhcp ack: %v", e)
- }
- if err := cfg.decode(opts); err != nil {
- return fmt.Errorf("dhcp ack bad options: %v", err)
+ ack = true
+ return cfg, nil
+}
+
+func loadDHCPReply(h header, typ dhcpMsgType, xid []byte) (opts options, valid bool, err error) {
+ if !h.isValid() || h.op() != opReply || !bytes.Equal(h.xidbytes(), xid[:]) {
+ return nil, false, nil
}
- msgtype, e := opts.dhcpMsgType()
- if e != nil {
- return fmt.Errorf("dhcp ack: %v", e)
+ opts, err = h.options()
+ if err != nil {
+ return nil, false, err
}
- ack = msgtype == dhcpACK
- if !ack {
- return fmt.Errorf("dhcp: request not acknowledged")
+ msgtype, err := opts.dhcpMsgType()
+ if err != nil {
+ return nil, false, err
}
- if cfg.LeaseLength != 0 {
- go c.renewAfter(cfg.LeaseLength)
+ if msgtype != typ {
+ return nil, false, nil
}
- return nil
-}
-
-func (c *Client) renewAfter(d time.Duration) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.cancelRenew != nil {
- c.cancelRenew()
- }
- ctx, cancel := context.WithCancel(context.Background())
- c.cancelRenew = cancel
- go func() {
- timer := time.NewTimer(d)
- defer timer.Stop()
- select {
- case <-ctx.Done():
- case <-timer.C:
- if err := c.Request(ctx, c.addr); err != nil {
- log.Printf("address renewal failed: %v", err)
- go c.renewAfter(1 * time.Minute)
- }
- }
- }()
+ return opts, true, nil
}