diff options
author | Tamir Duberstein <tamird@gmail.com> | 2018-08-25 06:16:34 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-08-25 06:17:32 -0700 |
commit | b17e80ef5a44e773e9032e7dbcb7438ff851ab7c (patch) | |
tree | 5c40dd5e44d70c51c2089ec10b51bd480fb8be50 /pkg/dhcp/client.go | |
parent | 106de2182d34197d76fb68863cd4a102ebac2dbb (diff) |
Upstreaming DHCP changes from Fuchsia
PiperOrigin-RevId: 210221388
Change-Id: Ic82d592b8c4778855fa55ba913f6b9a10b2d511f
Diffstat (limited to 'pkg/dhcp/client.go')
-rw-r--r-- | pkg/dhcp/client.go | 285 |
1 files changed, 174 insertions, 111 deletions
diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go index 8b5fc0452..909040e79 100644 --- a/pkg/dhcp/client.go +++ b/pkg/dhcp/client.go @@ -18,7 +18,6 @@ import ( "bytes" "context" "fmt" - "log" "sync" "time" @@ -32,9 +31,10 @@ import ( // Client is a DHCP client. type Client struct { - stack *stack.Stack - nicid tcpip.NICID - linkAddr tcpip.LinkAddress + stack *stack.Stack + nicid tcpip.NICID + linkAddr tcpip.LinkAddress + acquiredFunc func(old, new tcpip.Address, cfg Config) mu sync.Mutex addr tcpip.Address @@ -46,29 +46,57 @@ type Client struct { // NewClient creates a DHCP client. // // TODO: add s.LinkAddr(nicid) to *stack.Stack. -func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client { +func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress, acquiredFunc func(old, new tcpip.Address, cfg Config)) *Client { return &Client{ - stack: s, - nicid: nicid, - linkAddr: linkAddr, + stack: s, + nicid: nicid, + linkAddr: linkAddr, + acquiredFunc: acquiredFunc, } } -// Start starts the DHCP client. +// Run starts the DHCP client. // It will periodically search for an IP address using the Request method. -func (c *Client) Start() { - go func() { - for { - log.Print("DHCP request") - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - err := c.Request(ctx, "") - cancel() - if err == nil { - break - } +func (c *Client) Run(ctx context.Context) { + go c.run(ctx) +} + +func (c *Client) run(ctx context.Context) { + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + if c.addr != "" { + c.stack.RemoveAddress(c.nicid, c.addr) } - log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength) }() + + var renewAddr tcpip.Address + for { + reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + cfg, err := c.Request(reqCtx, renewAddr) + cancel() + if err != nil { + select { + case <-time.After(1 * time.Second): + // loop and try again + case <-ctx.Done(): + return + } + } + + c.mu.Lock() + renewAddr = c.addr + c.mu.Unlock() + + timer := time.NewTimer(cfg.LeaseLength) + select { + case <-ctx.Done(): + timer.Stop() + return + case <-timer.C: + // loop and make a renewal request + } + } } // Address reports the IP address acquired by the DHCP client. @@ -85,56 +113,53 @@ func (c *Client) Config() Config { return c.cfg } -// Shutdown relinquishes any lease and ends any outstanding renewal timers. -func (c *Client) Shutdown() { - c.mu.Lock() - defer c.mu.Unlock() - if c.addr != "" { - c.stack.RemoveAddress(c.nicid, c.addr) - } - if c.cancelRenew != nil { - c.cancelRenew() - } -} - // Request executes a DHCP request session. // // On success, it adds a new address to this client's TCPIP stack. // If the server sets a lease limit a timer is set to automatically // renew it. -func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error { +func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg Config, reterr error) { + if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil && err != tcpip.ErrDuplicateAddress { + return Config{}, fmt.Errorf("dhcp: %v", err) + } + if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil && err != tcpip.ErrDuplicateAddress { + return Config{}, fmt.Errorf("dhcp: %v", err) + } + defer c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff") + defer c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00") + var wq waiter.Queue ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - return fmt.Errorf("dhcp: outbound endpoint: %v", err) + return Config{}, fmt.Errorf("dhcp: outbound endpoint: %v", err) } - err = ep.Bind(tcpip.FullAddress{ - Addr: "\x00\x00\x00\x00", - Port: clientPort, - }, nil) defer ep.Close() - if err != nil { - return fmt.Errorf("dhcp: connect failed: %v", err) + if err := ep.Bind(tcpip.FullAddress{ + Addr: "\x00\x00\x00\x00", + Port: ClientPort, + NIC: c.nicid, + }, nil); err != nil { + return Config{}, fmt.Errorf("dhcp: connect failed: %v", err) } epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - return fmt.Errorf("dhcp: inbound endpoint: %v", err) + return Config{}, fmt.Errorf("dhcp: inbound endpoint: %v", err) } - err = epin.Bind(tcpip.FullAddress{ - Addr: "\xff\xff\xff\xff", - Port: clientPort, - }, nil) defer epin.Close() - if err != nil { - return fmt.Errorf("dhcp: connect failed: %v", err) + if err := epin.Bind(tcpip.FullAddress{ + Addr: "\xff\xff\xff\xff", + Port: ClientPort, + NIC: c.nicid, + }, nil); err != nil { + return Config{}, fmt.Errorf("dhcp: connect failed: %v", err) } var xid [4]byte rand.Read(xid[:]) // DHCPDISCOVERY - options := options{ + discOpts := options{ {optDHCPMsgType, []byte{byte(dhcpDISCOVER)}}, {optParamReq, []byte{ 1, // request subnet mask @@ -144,25 +169,34 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error }}, } if requestedAddr != "" { - options = append(options, option{optReqIPAddr, []byte(requestedAddr)}) + discOpts = append(discOpts, option{optReqIPAddr, []byte(requestedAddr)}) } - h := make(header, headerBaseSize+options.len()) + var clientID []byte + if len(c.linkAddr) == 6 { + clientID = append( + []byte{1}, // RFC 1700: Hardware Type [Ethernet = 1] + c.linkAddr..., + ) + discOpts = append(discOpts, option{optClientID, clientID}) + } + h := make(header, headerBaseSize+discOpts.len()+1) h.init() h.setOp(opRequest) copy(h.xidbytes(), xid[:]) h.setBroadcast() copy(h.chaddr(), c.linkAddr) - h.setOptions(options) + h.setOptions(discOpts) serverAddr := &tcpip.FullAddress{ Addr: "\xff\xff\xff\xff", - Port: serverPort, + Port: ServerPort, + NIC: c.nicid, } wopts := tcpip.WriteOptions{ To: serverAddr, } if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil { - return fmt.Errorf("dhcp discovery write: %v", err) + return Config{}, fmt.Errorf("dhcp discovery write: %v", err) } we, ch := waiter.NewChannelEntry(nil) @@ -170,6 +204,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error defer wq.EventUnregister(&we) // DHCPOFFER + var opts options for { var addr tcpip.FullAddress v, _, err := epin.Read(&addr) @@ -178,49 +213,84 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error case <-ch: continue case <-ctx.Done(): - return fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted) + return Config{}, fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted) } } h = header(v) - if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) { - break + var valid bool + var e error + opts, valid, e = loadDHCPReply(h, dhcpOFFER, xid[:]) + if !valid { + if e != nil { + // TODO: handle all the errors? + // TODO: report malformed server responses + } + continue } - } - if _, err := h.options(); err != nil { - return fmt.Errorf("dhcp offer: %v", err) + break } var ack bool - var cfg Config + if err := cfg.decode(opts); err != nil { + return Config{}, fmt.Errorf("dhcp offer: %v", err) + } // DHCPREQUEST addr := tcpip.Address(h.yiaddr()) if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil { if err != tcpip.ErrDuplicateAddress { - return fmt.Errorf("adding address: %v", err) + return Config{}, fmt.Errorf("adding address: %v", err) } } defer func() { - if ack { - c.mu.Lock() - c.addr = addr - c.cfg = cfg - c.mu.Unlock() - } else { + if !ack || reterr != nil { c.stack.RemoveAddress(c.nicid, addr) + addr = "" + cfg = Config{Error: reterr} + } + + c.mu.Lock() + oldAddr := c.addr + c.addr = addr + c.cfg = cfg + c.mu.Unlock() + + // Clean up broadcast addresses before calling acquiredFunc + // so nothing else uses them by mistake. + // + // (The deferred RemoveAddress calls above silently error.) + c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff") + c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00") + + if c.acquiredFunc != nil { + c.acquiredFunc(oldAddr, addr, cfg) + } + if requestedAddr != "" && requestedAddr != addr { + c.stack.RemoveAddress(c.nicid, requestedAddr) } }() + h.init() h.setOp(opRequest) for i, b := 0, h.yiaddr(); i < len(b); i++ { b[i] = 0 } - h.setOptions([]option{ + for i, b := 0, h.siaddr(); i < len(b); i++ { + b[i] = 0 + } + for i, b := 0, h.giaddr(); i < len(b); i++ { + b[i] = 0 + } + reqOpts := []option{ {optDHCPMsgType, []byte{byte(dhcpREQUEST)}}, {optReqIPAddr, []byte(addr)}, - {optDHCPServer, h.siaddr()}, - }) + {optDHCPServer, []byte(cfg.ServerAddress)}, + } + if len(clientID) != 0 { + reqOpts = append(reqOpts, option{optClientID, clientID}) + } + h.setOptions(reqOpts) if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil { - return fmt.Errorf("dhcp discovery write: %v", err) + return Config{}, fmt.Errorf("dhcp discovery write: %v", err) } // DHCPACK @@ -232,53 +302,46 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error case <-ch: continue case <-ctx.Done(): - return fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted) + return Config{}, fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted) } } h = header(v) - if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) { - break + var valid bool + var e error + opts, valid, e = loadDHCPReply(h, dhcpACK, xid[:]) + if !valid { + if e != nil { + // TODO: handle all the errors? + // TODO: report malformed server responses + } + if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid { + if msg := opts.message(); msg != "" { + return Config{}, fmt.Errorf("dhcp: NAK %q", msg) + } + return Config{}, fmt.Errorf("dhcp: NAK with no message") + } + continue } + break } - opts, e := h.options() - if e != nil { - return fmt.Errorf("dhcp ack: %v", e) - } - if err := cfg.decode(opts); err != nil { - return fmt.Errorf("dhcp ack bad options: %v", err) + ack = true + return cfg, nil +} + +func loadDHCPReply(h header, typ dhcpMsgType, xid []byte) (opts options, valid bool, err error) { + if !h.isValid() || h.op() != opReply || !bytes.Equal(h.xidbytes(), xid[:]) { + return nil, false, nil } - msgtype, e := opts.dhcpMsgType() - if e != nil { - return fmt.Errorf("dhcp ack: %v", e) + opts, err = h.options() + if err != nil { + return nil, false, err } - ack = msgtype == dhcpACK - if !ack { - return fmt.Errorf("dhcp: request not acknowledged") + msgtype, err := opts.dhcpMsgType() + if err != nil { + return nil, false, err } - if cfg.LeaseLength != 0 { - go c.renewAfter(cfg.LeaseLength) + if msgtype != typ { + return nil, false, nil } - return nil -} - -func (c *Client) renewAfter(d time.Duration) { - c.mu.Lock() - defer c.mu.Unlock() - if c.cancelRenew != nil { - c.cancelRenew() - } - ctx, cancel := context.WithCancel(context.Background()) - c.cancelRenew = cancel - go func() { - timer := time.NewTimer(d) - defer timer.Stop() - select { - case <-ctx.Done(): - case <-timer.C: - if err := c.Request(ctx, c.addr); err != nil { - log.Printf("address renewal failed: %v", err) - go c.renewAfter(1 * time.Minute) - } - } - }() + return opts, true, nil } |