diff options
Diffstat (limited to 'pkg/dhcp/client.go')
-rw-r--r-- | pkg/dhcp/client.go | 274 |
1 files changed, 274 insertions, 0 deletions
diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go new file mode 100644 index 000000000..9a4fd7ae4 --- /dev/null +++ b/pkg/dhcp/client.go @@ -0,0 +1,274 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dhcp + +import ( + "bytes" + "context" + "crypto/rand" + "fmt" + "log" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Client is a DHCP client. +type Client struct { + stack *stack.Stack + nicid tcpip.NICID + linkAddr tcpip.LinkAddress + + mu sync.Mutex + addr tcpip.Address + cfg Config + lease time.Duration + cancelRenew func() +} + +// NewClient creates a DHCP client. +// +// TODO: add s.LinkAddr(nicid) to *stack.Stack. +func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client { + return &Client{ + stack: s, + nicid: nicid, + linkAddr: linkAddr, + } +} + +// Start starts the DHCP client. +// It will periodically search for an IP address using the Request method. +func (c *Client) Start() { + go func() { + for { + log.Print("DHCP request") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := c.Request(ctx, "") + cancel() + if err == nil { + break + } + } + log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength) + }() +} + +// Address reports the IP address acquired by the DHCP client. +func (c *Client) Address() tcpip.Address { + c.mu.Lock() + defer c.mu.Unlock() + return c.addr +} + +// Config reports the DHCP configuration acquired with the IP address lease. +func (c *Client) Config() Config { + c.mu.Lock() + defer c.mu.Unlock() + return c.cfg +} + +// Shutdown relinquishes any lease and ends any outstanding renewal timers. +func (c *Client) Shutdown() { + c.mu.Lock() + defer c.mu.Unlock() + if c.addr != "" { + c.stack.RemoveAddress(c.nicid, c.addr) + } + if c.cancelRenew != nil { + c.cancelRenew() + } +} + +// Request executes a DHCP request session. +// +// On success, it adds a new address to this client's TCPIP stack. +// If the server sets a lease limit a timer is set to automatically +// renew it. +func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error { + var wq waiter.Queue + ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + return fmt.Errorf("dhcp: outbound endpoint: %v", err) + } + err = ep.Bind(tcpip.FullAddress{ + Addr: "\x00\x00\x00\x00", + Port: clientPort, + }, nil) + defer ep.Close() + if err != nil { + return fmt.Errorf("dhcp: connect failed: %v", err) + } + + epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + return fmt.Errorf("dhcp: inbound endpoint: %v", err) + } + err = epin.Bind(tcpip.FullAddress{ + Addr: "\xff\xff\xff\xff", + Port: clientPort, + }, nil) + defer epin.Close() + if err != nil { + return fmt.Errorf("dhcp: connect failed: %v", err) + } + + var xid [4]byte + rand.Read(xid[:]) + + // DHCPDISCOVERY + options := options{ + {optDHCPMsgType, []byte{byte(dhcpDISCOVER)}}, + {optParamReq, []byte{ + 1, // request subnet mask + 3, // request router + 15, // domain name + 6, // domain name server + }}, + } + if requestedAddr != "" { + options = append(options, option{optReqIPAddr, []byte(requestedAddr)}) + } + h := make(header, headerBaseSize+options.len()) + h.init() + h.setOp(opRequest) + copy(h.xidbytes(), xid[:]) + h.setBroadcast() + copy(h.chaddr(), c.linkAddr) + h.setOptions(options) + + serverAddr := &tcpip.FullAddress{ + Addr: "\xff\xff\xff\xff", + Port: serverPort, + } + wopts := tcpip.WriteOptions{ + To: serverAddr, + } + if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil { + return fmt.Errorf("dhcp discovery write: %v", err) + } + + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + // DHCPOFFER + for { + var addr tcpip.FullAddress + v, err := epin.Read(&addr) + if err == tcpip.ErrWouldBlock { + select { + case <-ch: + continue + case <-ctx.Done(): + return fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted) + } + } + h = header(v) + if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) { + break + } + } + if _, err := h.options(); err != nil { + return fmt.Errorf("dhcp offer: %v", err) + } + + var ack bool + var cfg Config + + // DHCPREQUEST + addr := tcpip.Address(h.yiaddr()) + if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil { + if err != tcpip.ErrDuplicateAddress { + return fmt.Errorf("adding address: %v", err) + } + } + defer func() { + if ack { + c.mu.Lock() + c.addr = addr + c.cfg = cfg + c.mu.Unlock() + } else { + c.stack.RemoveAddress(c.nicid, addr) + } + }() + h.setOp(opRequest) + for i, b := 0, h.yiaddr(); i < len(b); i++ { + b[i] = 0 + } + h.setOptions([]option{ + {optDHCPMsgType, []byte{byte(dhcpREQUEST)}}, + {optReqIPAddr, []byte(addr)}, + {optDHCPServer, h.siaddr()}, + }) + if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil { + return fmt.Errorf("dhcp discovery write: %v", err) + } + + // DHCPACK + for { + var addr tcpip.FullAddress + v, err := epin.Read(&addr) + if err == tcpip.ErrWouldBlock { + select { + case <-ch: + continue + case <-ctx.Done(): + return fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted) + } + } + h = header(v) + if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) { + break + } + } + opts, e := h.options() + if e != nil { + return fmt.Errorf("dhcp ack: %v", e) + } + if err := cfg.decode(opts); err != nil { + return fmt.Errorf("dhcp ack bad options: %v", err) + } + msgtype, e := opts.dhcpMsgType() + if e != nil { + return fmt.Errorf("dhcp ack: %v", e) + } + ack = msgtype == dhcpACK + if !ack { + return fmt.Errorf("dhcp: request not acknowledged") + } + if cfg.LeaseLength != 0 { + go c.renewAfter(cfg.LeaseLength) + } + return nil +} + +func (c *Client) renewAfter(d time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + if c.cancelRenew != nil { + c.cancelRenew() + } + ctx, cancel := context.WithCancel(context.Background()) + c.cancelRenew = cancel + go func() { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + if err := c.Request(ctx, c.addr); err != nil { + log.Printf("address renewal failed: %v", err) + go c.renewAfter(1 * time.Minute) + } + } + }() +} |