summaryrefslogtreecommitdiffhomepage
path: root/pkg/dhcp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/dhcp')
-rw-r--r--pkg/dhcp/BUILD36
-rw-r--r--pkg/dhcp/client.go274
-rw-r--r--pkg/dhcp/dhcp.go263
-rw-r--r--pkg/dhcp/dhcp_test.go113
-rw-r--r--pkg/dhcp/server.go289
5 files changed, 975 insertions, 0 deletions
diff --git a/pkg/dhcp/BUILD b/pkg/dhcp/BUILD
new file mode 100644
index 000000000..b40860aac
--- /dev/null
+++ b/pkg/dhcp/BUILD
@@ -0,0 +1,36 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "dhcp",
+ srcs = [
+ "client.go",
+ "dhcp.go",
+ "server.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/dhcp",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "dhcp_test",
+ size = "small",
+ srcs = ["dhcp_test.go"],
+ embed = [":dhcp"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/udp",
+ ],
+)
diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go
new file mode 100644
index 000000000..9a4fd7ae4
--- /dev/null
+++ b/pkg/dhcp/client.go
@@ -0,0 +1,274 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dhcp
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "fmt"
+ "log"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Client is a DHCP client.
+type Client struct {
+ stack *stack.Stack
+ nicid tcpip.NICID
+ linkAddr tcpip.LinkAddress
+
+ mu sync.Mutex
+ addr tcpip.Address
+ cfg Config
+ lease time.Duration
+ cancelRenew func()
+}
+
+// NewClient creates a DHCP client.
+//
+// TODO: add s.LinkAddr(nicid) to *stack.Stack.
+func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client {
+ return &Client{
+ stack: s,
+ nicid: nicid,
+ linkAddr: linkAddr,
+ }
+}
+
+// Start starts the DHCP client.
+// It will periodically search for an IP address using the Request method.
+func (c *Client) Start() {
+ go func() {
+ for {
+ log.Print("DHCP request")
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ err := c.Request(ctx, "")
+ cancel()
+ if err == nil {
+ break
+ }
+ }
+ log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength)
+ }()
+}
+
+// Address reports the IP address acquired by the DHCP client.
+func (c *Client) Address() tcpip.Address {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.addr
+}
+
+// Config reports the DHCP configuration acquired with the IP address lease.
+func (c *Client) Config() Config {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.cfg
+}
+
+// Shutdown relinquishes any lease and ends any outstanding renewal timers.
+func (c *Client) Shutdown() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.addr != "" {
+ c.stack.RemoveAddress(c.nicid, c.addr)
+ }
+ if c.cancelRenew != nil {
+ c.cancelRenew()
+ }
+}
+
+// Request executes a DHCP request session.
+//
+// On success, it adds a new address to this client's TCPIP stack.
+// If the server sets a lease limit a timer is set to automatically
+// renew it.
+func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error {
+ var wq waiter.Queue
+ ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ return fmt.Errorf("dhcp: outbound endpoint: %v", err)
+ }
+ err = ep.Bind(tcpip.FullAddress{
+ Addr: "\x00\x00\x00\x00",
+ Port: clientPort,
+ }, nil)
+ defer ep.Close()
+ if err != nil {
+ return fmt.Errorf("dhcp: connect failed: %v", err)
+ }
+
+ epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ return fmt.Errorf("dhcp: inbound endpoint: %v", err)
+ }
+ err = epin.Bind(tcpip.FullAddress{
+ Addr: "\xff\xff\xff\xff",
+ Port: clientPort,
+ }, nil)
+ defer epin.Close()
+ if err != nil {
+ return fmt.Errorf("dhcp: connect failed: %v", err)
+ }
+
+ var xid [4]byte
+ rand.Read(xid[:])
+
+ // DHCPDISCOVERY
+ options := options{
+ {optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
+ {optParamReq, []byte{
+ 1, // request subnet mask
+ 3, // request router
+ 15, // domain name
+ 6, // domain name server
+ }},
+ }
+ if requestedAddr != "" {
+ options = append(options, option{optReqIPAddr, []byte(requestedAddr)})
+ }
+ h := make(header, headerBaseSize+options.len())
+ h.init()
+ h.setOp(opRequest)
+ copy(h.xidbytes(), xid[:])
+ h.setBroadcast()
+ copy(h.chaddr(), c.linkAddr)
+ h.setOptions(options)
+
+ serverAddr := &tcpip.FullAddress{
+ Addr: "\xff\xff\xff\xff",
+ Port: serverPort,
+ }
+ wopts := tcpip.WriteOptions{
+ To: serverAddr,
+ }
+ if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
+ return fmt.Errorf("dhcp discovery write: %v", err)
+ }
+
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ // DHCPOFFER
+ for {
+ var addr tcpip.FullAddress
+ v, err := epin.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ select {
+ case <-ch:
+ continue
+ case <-ctx.Done():
+ return fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
+ }
+ }
+ h = header(v)
+ if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
+ break
+ }
+ }
+ if _, err := h.options(); err != nil {
+ return fmt.Errorf("dhcp offer: %v", err)
+ }
+
+ var ack bool
+ var cfg Config
+
+ // DHCPREQUEST
+ addr := tcpip.Address(h.yiaddr())
+ if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
+ if err != tcpip.ErrDuplicateAddress {
+ return fmt.Errorf("adding address: %v", err)
+ }
+ }
+ defer func() {
+ if ack {
+ c.mu.Lock()
+ c.addr = addr
+ c.cfg = cfg
+ c.mu.Unlock()
+ } else {
+ c.stack.RemoveAddress(c.nicid, addr)
+ }
+ }()
+ h.setOp(opRequest)
+ for i, b := 0, h.yiaddr(); i < len(b); i++ {
+ b[i] = 0
+ }
+ h.setOptions([]option{
+ {optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
+ {optReqIPAddr, []byte(addr)},
+ {optDHCPServer, h.siaddr()},
+ })
+ if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
+ return fmt.Errorf("dhcp discovery write: %v", err)
+ }
+
+ // DHCPACK
+ for {
+ var addr tcpip.FullAddress
+ v, err := epin.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ select {
+ case <-ch:
+ continue
+ case <-ctx.Done():
+ return fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
+ }
+ }
+ h = header(v)
+ if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
+ break
+ }
+ }
+ opts, e := h.options()
+ if e != nil {
+ return fmt.Errorf("dhcp ack: %v", e)
+ }
+ if err := cfg.decode(opts); err != nil {
+ return fmt.Errorf("dhcp ack bad options: %v", err)
+ }
+ msgtype, e := opts.dhcpMsgType()
+ if e != nil {
+ return fmt.Errorf("dhcp ack: %v", e)
+ }
+ ack = msgtype == dhcpACK
+ if !ack {
+ return fmt.Errorf("dhcp: request not acknowledged")
+ }
+ if cfg.LeaseLength != 0 {
+ go c.renewAfter(cfg.LeaseLength)
+ }
+ return nil
+}
+
+func (c *Client) renewAfter(d time.Duration) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.cancelRenew != nil {
+ c.cancelRenew()
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ c.cancelRenew = cancel
+ go func() {
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ case <-timer.C:
+ if err := c.Request(ctx, c.addr); err != nil {
+ log.Printf("address renewal failed: %v", err)
+ go c.renewAfter(1 * time.Minute)
+ }
+ }
+ }()
+}
diff --git a/pkg/dhcp/dhcp.go b/pkg/dhcp/dhcp.go
new file mode 100644
index 000000000..762086853
--- /dev/null
+++ b/pkg/dhcp/dhcp.go
@@ -0,0 +1,263 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package dhcp implements a DHCP client and server as described in RFC 2131.
+package dhcp
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// Config is standard DHCP configuration.
+type Config struct {
+ ServerAddress tcpip.Address // address of the server
+ SubnetMask tcpip.AddressMask // client address subnet mask
+ Gateway tcpip.Address // client default gateway
+ DomainNameServer tcpip.Address // client domain name server
+ LeaseLength time.Duration // length of the address lease
+}
+
+func (cfg *Config) decode(opts []option) error {
+ *cfg = Config{}
+ for _, opt := range opts {
+ b := opt.body
+ if l := opt.code.len(); l != -1 && l != len(b) {
+ return fmt.Errorf("%s bad length: %d", opt.code, len(b))
+ }
+ switch opt.code {
+ case optLeaseTime:
+ t := binary.BigEndian.Uint32(b)
+ cfg.LeaseLength = time.Duration(t) * time.Second
+ case optSubnetMask:
+ cfg.SubnetMask = tcpip.AddressMask(b)
+ case optDHCPServer:
+ cfg.ServerAddress = tcpip.Address(b)
+ case optDefaultGateway:
+ cfg.Gateway = tcpip.Address(b)
+ case optDomainNameServer:
+ cfg.DomainNameServer = tcpip.Address(b)
+ }
+ }
+ return nil
+}
+
+func (cfg Config) encode() (opts []option) {
+ if cfg.ServerAddress != "" {
+ opts = append(opts, option{optDHCPServer, []byte(cfg.ServerAddress)})
+ }
+ if cfg.SubnetMask != "" {
+ opts = append(opts, option{optSubnetMask, []byte(cfg.SubnetMask)})
+ }
+ if cfg.Gateway != "" {
+ opts = append(opts, option{optDefaultGateway, []byte(cfg.Gateway)})
+ }
+ if cfg.DomainNameServer != "" {
+ opts = append(opts, option{optDomainNameServer, []byte(cfg.DomainNameServer)})
+ }
+ if l := cfg.LeaseLength / time.Second; l != 0 {
+ v := make([]byte, 4)
+ v[0] = byte(l >> 24)
+ v[1] = byte(l >> 16)
+ v[2] = byte(l >> 8)
+ v[3] = byte(l >> 0)
+ opts = append(opts, option{optLeaseTime, v})
+ }
+ return opts
+}
+
+const (
+ serverPort = 67
+ clientPort = 68
+)
+
+var magicCookie = []byte{99, 130, 83, 99} // RFC 1497
+
+type xid uint32
+
+type header []byte
+
+func (h header) init() {
+ h[1] = 0x01 // htype
+ h[2] = 0x06 // hlen
+ h[3] = 0x00 // hops
+ h[8], h[9] = 0, 0 // secs
+ copy(h[236:240], magicCookie)
+}
+
+func (h header) isValid() bool {
+ if len(h) < 241 {
+ return false
+ }
+ if o := h.op(); o != opRequest && o != opReply {
+ return false
+ }
+ if h[1] != 0x01 || h[2] != 0x06 || h[3] != 0x00 {
+ return false
+ }
+ return bytes.Equal(h[236:240], magicCookie) && h[len(h)-1] == 0
+}
+
+func (h header) op() op { return op(h[0]) }
+func (h header) setOp(o op) { h[0] = byte(o) }
+func (h header) xidbytes() []byte { return h[4:8] }
+func (h header) xid() xid { return xid(h[4])<<24 | xid(h[5])<<16 | xid(h[6])<<8 | xid(h[7]) }
+func (h header) setBroadcast() { h[10], h[11] = 0x80, 0x00 } // flags top bit
+func (h header) ciaddr() []byte { return h[12:16] }
+func (h header) yiaddr() []byte { return h[16:20] }
+func (h header) siaddr() []byte { return h[20:24] }
+func (h header) giaddr() []byte { return h[24:28] }
+func (h header) chaddr() []byte { return h[28:44] }
+func (h header) sname() []byte { return h[44:108] }
+func (h header) file() []byte { return h[108:236] }
+
+func (h header) options() (opts options, err error) {
+ i := headerBaseSize
+ for i < len(h) {
+ if h[i] == 0 {
+ i++
+ continue
+ }
+ if h[i] == 255 {
+ break
+ }
+ if len(h) <= i+1 {
+ return nil, fmt.Errorf("option missing length")
+ }
+ optlen := int(h[i+1])
+ if len(h) < i+2+optlen {
+ return nil, fmt.Errorf("option too long")
+ }
+ opts = append(opts, option{
+ code: optionCode(h[i]),
+ body: h[i+2 : i+2+optlen],
+ })
+ i += 2 + optlen
+ }
+ return opts, nil
+}
+
+func (h header) setOptions(opts []option) {
+ i := headerBaseSize
+ for _, opt := range opts {
+ h[i] = byte(opt.code)
+ h[i+1] = byte(len(opt.body))
+ copy(h[i+2:i+2+len(opt.body)], opt.body)
+ i += 2 + len(opt.body)
+ }
+ for ; i < len(h); i++ {
+ h[i] = 0
+ }
+}
+
+// headerBaseSize is the size of a DHCP packet, including the magic cookie.
+//
+// Note that a DHCP packet is required to have an 'end' option that takes
+// up an extra byte, so the minimum DHCP packet size is headerBaseSize + 1.
+const headerBaseSize = 240
+
+type option struct {
+ code optionCode
+ body []byte
+}
+
+type optionCode byte
+
+const (
+ optSubnetMask optionCode = 1
+ optDefaultGateway optionCode = 3
+ optDomainNameServer optionCode = 6
+ optReqIPAddr optionCode = 50
+ optLeaseTime optionCode = 51
+ optDHCPMsgType optionCode = 53 // dhcpMsgType
+ optDHCPServer optionCode = 54
+ optParamReq optionCode = 55
+)
+
+func (code optionCode) len() int {
+ switch code {
+ case optSubnetMask, optDefaultGateway, optDomainNameServer,
+ optReqIPAddr, optLeaseTime, optDHCPServer:
+ return 4
+ case optDHCPMsgType:
+ return 1
+ case optParamReq:
+ return -1 // no fixed length
+ default:
+ return -1
+ }
+}
+
+func (code optionCode) String() string {
+ switch code {
+ case optSubnetMask:
+ return "option(subnet-mask)"
+ case optDefaultGateway:
+ return "option(default-gateway)"
+ case optDomainNameServer:
+ return "option(dns)"
+ case optReqIPAddr:
+ return "option(request-ip-address)"
+ case optLeaseTime:
+ return "option(least-time)"
+ case optDHCPMsgType:
+ return "option(message-type)"
+ case optDHCPServer:
+ return "option(server)"
+ case optParamReq:
+ return "option(parameter-request)"
+ default:
+ return fmt.Sprintf("option(%d)", code)
+ }
+}
+
+type options []option
+
+func (opts options) dhcpMsgType() (dhcpMsgType, error) {
+ for _, opt := range opts {
+ if opt.code == optDHCPMsgType {
+ if len(opt.body) != 1 {
+ return 0, fmt.Errorf("%s: bad length: %d", optDHCPMsgType, len(opt.body))
+ }
+ v := opt.body[0]
+ if v <= 0 || v >= 8 {
+ return 0, fmt.Errorf("%s: unknown value: %d", optDHCPMsgType, v)
+ }
+ return dhcpMsgType(v), nil
+ }
+ }
+ return 0, nil
+}
+
+func (opts options) len() int {
+ l := 0
+ for _, opt := range opts {
+ l += 1 + 1 + len(opt.body) // code + len + body
+ }
+ return l + 1 // extra byte for 'pad' option
+}
+
+type op byte
+
+const (
+ opRequest op = 0x01
+ opReply op = 0x02
+)
+
+// dhcpMsgType is the DHCP Message Type from RFC 1533, section 9.4.
+type dhcpMsgType byte
+
+const (
+ dhcpDISCOVER dhcpMsgType = 1
+ dhcpOFFER dhcpMsgType = 2
+ dhcpREQUEST dhcpMsgType = 3
+ dhcpDECLINE dhcpMsgType = 4
+ dhcpACK dhcpMsgType = 5
+ dhcpNAK dhcpMsgType = 6
+ dhcpRELEASE dhcpMsgType = 7
+)
diff --git a/pkg/dhcp/dhcp_test.go b/pkg/dhcp/dhcp_test.go
new file mode 100644
index 000000000..d56b93997
--- /dev/null
+++ b/pkg/dhcp/dhcp_test.go
@@ -0,0 +1,113 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dhcp
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+)
+
+func TestDHCP(t *testing.T) {
+ const defaultMTU = 65536
+ id, linkEP := channel.New(256, defaultMTU, "")
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+
+ go func() {
+ for pkt := range linkEP.C {
+ v := make(buffer.View, len(pkt.Header)+len(pkt.Payload))
+ copy(v, pkt.Header)
+ copy(v[len(pkt.Header):], pkt.Payload)
+ vv := v.ToVectorisedView([1]buffer.View{})
+ linkEP.Inject(pkt.Proto, &vv)
+ }
+ }()
+
+ s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName})
+
+ const nicid tcpip.NICID = 1
+ if err := s.CreateNIC(nicid, id); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil {
+ t.Fatal(err)
+ }
+ const serverAddr = tcpip.Address("\xc0\xa8\x03\x01")
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, serverAddr); err != nil {
+ t.Fatal(err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: tcpip.Address(strings.Repeat("\x00", 4)),
+ Mask: tcpip.Address(strings.Repeat("\x00", 4)),
+ Gateway: "",
+ NIC: nicid,
+ }})
+
+ var clientAddrs = []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"}
+
+ serverCfg := Config{
+ ServerAddress: serverAddr,
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DomainNameServer: "\x08\x08\x08\x08",
+ LeaseLength: 24 * time.Hour,
+ }
+ serverCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, err := NewServer(serverCtx, s, clientAddrs, serverCfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
+ c0 := NewClient(s, nicid, clientLinkAddr0)
+ if err := c0.Request(context.Background(), ""); err != nil {
+ t.Fatal(err)
+ }
+ if got, want := c0.Address(), clientAddrs[0]; got != want {
+ t.Errorf("c.Addr()=%s, want=%s", got, want)
+ }
+ if err := c0.Request(context.Background(), ""); err != nil {
+ t.Fatal(err)
+ }
+ if got, want := c0.Address(), clientAddrs[0]; got != want {
+ t.Errorf("c.Addr()=%s, want=%s", got, want)
+ }
+
+ const clientLinkAddr1 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x53")
+ c1 := NewClient(s, nicid, clientLinkAddr1)
+ if err := c1.Request(context.Background(), ""); err != nil {
+ t.Fatal(err)
+ }
+ if got, want := c1.Address(), clientAddrs[1]; got != want {
+ t.Errorf("c.Addr()=%s, want=%s", got, want)
+ }
+
+ if err := c0.Request(context.Background(), ""); err != nil {
+ t.Fatal(err)
+ }
+ if got, want := c0.Address(), clientAddrs[0]; got != want {
+ t.Errorf("c.Addr()=%s, want=%s", got, want)
+ }
+
+ if got, want := c0.Config(), serverCfg; got != want {
+ t.Errorf("client config:\n\t%#+v\nwant:\n\t%#+v", got, want)
+ }
+}
diff --git a/pkg/dhcp/server.go b/pkg/dhcp/server.go
new file mode 100644
index 000000000..d132d90b4
--- /dev/null
+++ b/pkg/dhcp/server.go
@@ -0,0 +1,289 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dhcp
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Server is a DHCP server.
+type Server struct {
+ stack *stack.Stack
+ broadcast tcpip.FullAddress
+ wq waiter.Queue
+ ep tcpip.Endpoint
+ addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure
+ cfg Config
+ cfgopts []option // cfg to send to client
+
+ handlers []chan header
+
+ mu sync.Mutex
+ leases map[tcpip.LinkAddress]serverLease
+}
+
+// NewServer creates a new DHCP server and begins serving.
+// The server continues serving until ctx is done.
+func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
+ s := &Server{
+ stack: stack,
+ addrs: addrs,
+ cfg: cfg,
+ cfgopts: cfg.encode(),
+ broadcast: tcpip.FullAddress{
+ Addr: "\xff\xff\xff\xff",
+ Port: clientPort,
+ },
+
+ handlers: make([]chan header, 8),
+ leases: make(map[tcpip.LinkAddress]serverLease),
+ }
+
+ var err *tcpip.Error
+ s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq)
+ if err != nil {
+ return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
+ }
+ serverBroadcast := tcpip.FullAddress{
+ Addr: "",
+ Port: serverPort,
+ }
+ if err := s.ep.Bind(serverBroadcast, nil); err != nil {
+ return nil, fmt.Errorf("dhcp: server bind: %v", err)
+ }
+
+ for i := 0; i < len(s.handlers); i++ {
+ ch := make(chan header, 8)
+ s.handlers[i] = ch
+ go s.handler(ctx, ch)
+ }
+
+ go s.expirer(ctx)
+ go s.reader(ctx)
+ return s, nil
+}
+
+func (s *Server) expirer(ctx context.Context) {
+ t := time.NewTicker(1 * time.Minute)
+ defer t.Stop()
+ for {
+ select {
+ case <-t.C:
+ s.mu.Lock()
+ for linkAddr, lease := range s.leases {
+ if time.Since(lease.start) > s.cfg.LeaseLength {
+ lease.state = leaseExpired
+ s.leases[linkAddr] = lease
+ }
+ }
+ s.mu.Unlock()
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+// reader listens for all incoming DHCP packets and fans them out to
+// handling goroutines based on XID as session identifiers.
+func (s *Server) reader(ctx context.Context) {
+ we, ch := waiter.NewChannelEntry(nil)
+ s.wq.EventRegister(&we, waiter.EventIn)
+ defer s.wq.EventUnregister(&we)
+
+ for {
+ var addr tcpip.FullAddress
+ v, err := s.ep.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ select {
+ case <-ch:
+ continue
+ case <-ctx.Done():
+ return
+ }
+ }
+
+ h := header(v)
+ if !h.isValid() || h.op() != opRequest {
+ continue
+ }
+ xid := h.xid()
+
+ // Fan out the packet to a handler goroutine.
+ //
+ // Use a consistent handler for a given xid, so that
+ // packets from a particular client are processed
+ // in order.
+ ch := s.handlers[int(xid)%len(s.handlers)]
+ select {
+ case <-ctx.Done():
+ return
+ case ch <- h:
+ default:
+ // drop the packet
+ }
+ }
+}
+
+func (s *Server) handler(ctx context.Context, ch chan header) {
+ for {
+ select {
+ case h := <-ch:
+ if h == nil {
+ return
+ }
+ opts, err := h.options()
+ if err != nil {
+ continue
+ }
+ // TODO: Handle DHCPRELEASE and DHCPDECLINE.
+ msgtype, err := opts.dhcpMsgType()
+ if err != nil {
+ continue
+ }
+ switch msgtype {
+ case dhcpDISCOVER:
+ s.handleDiscover(h, opts)
+ case dhcpREQUEST:
+ s.handleRequest(h, opts)
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (s *Server) handleDiscover(hreq header, opts options) {
+ linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6])
+ xid := hreq.xid()
+
+ s.mu.Lock()
+ lease := s.leases[linkAddr]
+ switch lease.state {
+ case leaseNew:
+ if len(s.leases) < len(s.addrs) {
+ // Find an unused address.
+ // TODO: avoid building this state on each request.
+ alloced := make(map[tcpip.Address]bool)
+ for _, lease := range s.leases {
+ alloced[lease.addr] = true
+ }
+ for _, addr := range s.addrs {
+ if !alloced[addr] {
+ lease = serverLease{
+ start: time.Now(),
+ addr: addr,
+ xid: xid,
+ state: leaseOffer,
+ }
+ s.leases[linkAddr] = lease
+ break
+ }
+ }
+ } else {
+ // No more addresses, take an expired address.
+ for k, oldLease := range s.leases {
+ if oldLease.state == leaseExpired {
+ delete(s.leases, k)
+ lease = serverLease{
+ start: time.Now(),
+ addr: lease.addr,
+ xid: xid,
+ state: leaseOffer,
+ }
+ s.leases[linkAddr] = lease
+ break
+ }
+ }
+ log.Printf("server has no more addresses")
+ s.mu.Unlock()
+ return
+ }
+ case leaseOffer, leaseAck, leaseExpired:
+ lease = serverLease{
+ start: time.Now(),
+ addr: s.leases[linkAddr].addr,
+ xid: xid,
+ state: leaseOffer,
+ }
+ s.leases[linkAddr] = lease
+ }
+ s.mu.Unlock()
+
+ // DHCPOFFER
+ opts = options{{optDHCPMsgType, []byte{byte(dhcpOFFER)}}}
+ opts = append(opts, s.cfgopts...)
+ h := make(header, headerBaseSize+opts.len())
+ h.init()
+ h.setOp(opReply)
+ copy(h.xidbytes(), hreq.xidbytes())
+ copy(h.yiaddr(), lease.addr)
+ copy(h.siaddr(), s.cfg.ServerAddress)
+ copy(h.chaddr(), hreq.chaddr())
+ h.setOptions(opts)
+ s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast})
+}
+
+func (s *Server) handleRequest(hreq header, opts options) {
+ linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6])
+ xid := hreq.xid()
+
+ s.mu.Lock()
+ lease := s.leases[linkAddr]
+ switch lease.state {
+ case leaseOffer, leaseAck, leaseExpired:
+ lease = serverLease{
+ start: time.Now(),
+ addr: s.leases[linkAddr].addr,
+ xid: xid,
+ state: leaseAck,
+ }
+ s.leases[linkAddr] = lease
+ }
+ s.mu.Unlock()
+
+ if lease.state == leaseNew {
+ // TODO: NACK or accept request
+ return
+ }
+
+ // DHCPACK
+ opts = []option{{optDHCPMsgType, []byte{byte(dhcpACK)}}}
+ opts = append(opts, s.cfgopts...)
+ h := make(header, headerBaseSize+opts.len())
+ h.init()
+ h.setOp(opReply)
+ copy(h.xidbytes(), hreq.xidbytes())
+ copy(h.yiaddr(), lease.addr)
+ copy(h.siaddr(), s.cfg.ServerAddress)
+ copy(h.chaddr(), hreq.chaddr())
+ h.setOptions(opts)
+ s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast})
+}
+
+type leaseState int
+
+const (
+ leaseNew leaseState = iota
+ leaseOffer
+ leaseAck
+ leaseExpired
+)
+
+type serverLease struct {
+ start time.Time
+ addr tcpip.Address
+ xid xid
+ state leaseState
+}