diff options
author | Googler <noreply@google.com> | 2018-04-27 10:37:02 -0700 |
---|---|---|
committer | Adin Scannell <ascannell@google.com> | 2018-04-28 01:44:26 -0400 |
commit | d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch) | |
tree | 54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/dhcp | |
parent | f70210e742919f40aa2f0934a22f1c9ba6dada62 (diff) |
Check in gVisor.
PiperOrigin-RevId: 194583126
Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/dhcp')
-rw-r--r-- | pkg/dhcp/BUILD | 36 | ||||
-rw-r--r-- | pkg/dhcp/client.go | 274 | ||||
-rw-r--r-- | pkg/dhcp/dhcp.go | 263 | ||||
-rw-r--r-- | pkg/dhcp/dhcp_test.go | 113 | ||||
-rw-r--r-- | pkg/dhcp/server.go | 289 |
5 files changed, 975 insertions, 0 deletions
diff --git a/pkg/dhcp/BUILD b/pkg/dhcp/BUILD new file mode 100644 index 000000000..b40860aac --- /dev/null +++ b/pkg/dhcp/BUILD @@ -0,0 +1,36 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "dhcp", + srcs = [ + "client.go", + "dhcp.go", + "server.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/dhcp", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + ], +) + +go_test( + name = "dhcp_test", + size = "small", + srcs = ["dhcp_test.go"], + embed = [":dhcp"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/udp", + ], +) diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go new file mode 100644 index 000000000..9a4fd7ae4 --- /dev/null +++ b/pkg/dhcp/client.go @@ -0,0 +1,274 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dhcp + +import ( + "bytes" + "context" + "crypto/rand" + "fmt" + "log" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Client is a DHCP client. +type Client struct { + stack *stack.Stack + nicid tcpip.NICID + linkAddr tcpip.LinkAddress + + mu sync.Mutex + addr tcpip.Address + cfg Config + lease time.Duration + cancelRenew func() +} + +// NewClient creates a DHCP client. +// +// TODO: add s.LinkAddr(nicid) to *stack.Stack. +func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client { + return &Client{ + stack: s, + nicid: nicid, + linkAddr: linkAddr, + } +} + +// Start starts the DHCP client. +// It will periodically search for an IP address using the Request method. +func (c *Client) Start() { + go func() { + for { + log.Print("DHCP request") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := c.Request(ctx, "") + cancel() + if err == nil { + break + } + } + log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength) + }() +} + +// Address reports the IP address acquired by the DHCP client. +func (c *Client) Address() tcpip.Address { + c.mu.Lock() + defer c.mu.Unlock() + return c.addr +} + +// Config reports the DHCP configuration acquired with the IP address lease. +func (c *Client) Config() Config { + c.mu.Lock() + defer c.mu.Unlock() + return c.cfg +} + +// Shutdown relinquishes any lease and ends any outstanding renewal timers. +func (c *Client) Shutdown() { + c.mu.Lock() + defer c.mu.Unlock() + if c.addr != "" { + c.stack.RemoveAddress(c.nicid, c.addr) + } + if c.cancelRenew != nil { + c.cancelRenew() + } +} + +// Request executes a DHCP request session. +// +// On success, it adds a new address to this client's TCPIP stack. +// If the server sets a lease limit a timer is set to automatically +// renew it. +func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error { + var wq waiter.Queue + ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + return fmt.Errorf("dhcp: outbound endpoint: %v", err) + } + err = ep.Bind(tcpip.FullAddress{ + Addr: "\x00\x00\x00\x00", + Port: clientPort, + }, nil) + defer ep.Close() + if err != nil { + return fmt.Errorf("dhcp: connect failed: %v", err) + } + + epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + return fmt.Errorf("dhcp: inbound endpoint: %v", err) + } + err = epin.Bind(tcpip.FullAddress{ + Addr: "\xff\xff\xff\xff", + Port: clientPort, + }, nil) + defer epin.Close() + if err != nil { + return fmt.Errorf("dhcp: connect failed: %v", err) + } + + var xid [4]byte + rand.Read(xid[:]) + + // DHCPDISCOVERY + options := options{ + {optDHCPMsgType, []byte{byte(dhcpDISCOVER)}}, + {optParamReq, []byte{ + 1, // request subnet mask + 3, // request router + 15, // domain name + 6, // domain name server + }}, + } + if requestedAddr != "" { + options = append(options, option{optReqIPAddr, []byte(requestedAddr)}) + } + h := make(header, headerBaseSize+options.len()) + h.init() + h.setOp(opRequest) + copy(h.xidbytes(), xid[:]) + h.setBroadcast() + copy(h.chaddr(), c.linkAddr) + h.setOptions(options) + + serverAddr := &tcpip.FullAddress{ + Addr: "\xff\xff\xff\xff", + Port: serverPort, + } + wopts := tcpip.WriteOptions{ + To: serverAddr, + } + if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil { + return fmt.Errorf("dhcp discovery write: %v", err) + } + + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + // DHCPOFFER + for { + var addr tcpip.FullAddress + v, err := epin.Read(&addr) + if err == tcpip.ErrWouldBlock { + select { + case <-ch: + continue + case <-ctx.Done(): + return fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted) + } + } + h = header(v) + if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) { + break + } + } + if _, err := h.options(); err != nil { + return fmt.Errorf("dhcp offer: %v", err) + } + + var ack bool + var cfg Config + + // DHCPREQUEST + addr := tcpip.Address(h.yiaddr()) + if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil { + if err != tcpip.ErrDuplicateAddress { + return fmt.Errorf("adding address: %v", err) + } + } + defer func() { + if ack { + c.mu.Lock() + c.addr = addr + c.cfg = cfg + c.mu.Unlock() + } else { + c.stack.RemoveAddress(c.nicid, addr) + } + }() + h.setOp(opRequest) + for i, b := 0, h.yiaddr(); i < len(b); i++ { + b[i] = 0 + } + h.setOptions([]option{ + {optDHCPMsgType, []byte{byte(dhcpREQUEST)}}, + {optReqIPAddr, []byte(addr)}, + {optDHCPServer, h.siaddr()}, + }) + if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil { + return fmt.Errorf("dhcp discovery write: %v", err) + } + + // DHCPACK + for { + var addr tcpip.FullAddress + v, err := epin.Read(&addr) + if err == tcpip.ErrWouldBlock { + select { + case <-ch: + continue + case <-ctx.Done(): + return fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted) + } + } + h = header(v) + if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) { + break + } + } + opts, e := h.options() + if e != nil { + return fmt.Errorf("dhcp ack: %v", e) + } + if err := cfg.decode(opts); err != nil { + return fmt.Errorf("dhcp ack bad options: %v", err) + } + msgtype, e := opts.dhcpMsgType() + if e != nil { + return fmt.Errorf("dhcp ack: %v", e) + } + ack = msgtype == dhcpACK + if !ack { + return fmt.Errorf("dhcp: request not acknowledged") + } + if cfg.LeaseLength != 0 { + go c.renewAfter(cfg.LeaseLength) + } + return nil +} + +func (c *Client) renewAfter(d time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + if c.cancelRenew != nil { + c.cancelRenew() + } + ctx, cancel := context.WithCancel(context.Background()) + c.cancelRenew = cancel + go func() { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + if err := c.Request(ctx, c.addr); err != nil { + log.Printf("address renewal failed: %v", err) + go c.renewAfter(1 * time.Minute) + } + } + }() +} diff --git a/pkg/dhcp/dhcp.go b/pkg/dhcp/dhcp.go new file mode 100644 index 000000000..762086853 --- /dev/null +++ b/pkg/dhcp/dhcp.go @@ -0,0 +1,263 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dhcp implements a DHCP client and server as described in RFC 2131. +package dhcp + +import ( + "bytes" + "encoding/binary" + "fmt" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" +) + +// Config is standard DHCP configuration. +type Config struct { + ServerAddress tcpip.Address // address of the server + SubnetMask tcpip.AddressMask // client address subnet mask + Gateway tcpip.Address // client default gateway + DomainNameServer tcpip.Address // client domain name server + LeaseLength time.Duration // length of the address lease +} + +func (cfg *Config) decode(opts []option) error { + *cfg = Config{} + for _, opt := range opts { + b := opt.body + if l := opt.code.len(); l != -1 && l != len(b) { + return fmt.Errorf("%s bad length: %d", opt.code, len(b)) + } + switch opt.code { + case optLeaseTime: + t := binary.BigEndian.Uint32(b) + cfg.LeaseLength = time.Duration(t) * time.Second + case optSubnetMask: + cfg.SubnetMask = tcpip.AddressMask(b) + case optDHCPServer: + cfg.ServerAddress = tcpip.Address(b) + case optDefaultGateway: + cfg.Gateway = tcpip.Address(b) + case optDomainNameServer: + cfg.DomainNameServer = tcpip.Address(b) + } + } + return nil +} + +func (cfg Config) encode() (opts []option) { + if cfg.ServerAddress != "" { + opts = append(opts, option{optDHCPServer, []byte(cfg.ServerAddress)}) + } + if cfg.SubnetMask != "" { + opts = append(opts, option{optSubnetMask, []byte(cfg.SubnetMask)}) + } + if cfg.Gateway != "" { + opts = append(opts, option{optDefaultGateway, []byte(cfg.Gateway)}) + } + if cfg.DomainNameServer != "" { + opts = append(opts, option{optDomainNameServer, []byte(cfg.DomainNameServer)}) + } + if l := cfg.LeaseLength / time.Second; l != 0 { + v := make([]byte, 4) + v[0] = byte(l >> 24) + v[1] = byte(l >> 16) + v[2] = byte(l >> 8) + v[3] = byte(l >> 0) + opts = append(opts, option{optLeaseTime, v}) + } + return opts +} + +const ( + serverPort = 67 + clientPort = 68 +) + +var magicCookie = []byte{99, 130, 83, 99} // RFC 1497 + +type xid uint32 + +type header []byte + +func (h header) init() { + h[1] = 0x01 // htype + h[2] = 0x06 // hlen + h[3] = 0x00 // hops + h[8], h[9] = 0, 0 // secs + copy(h[236:240], magicCookie) +} + +func (h header) isValid() bool { + if len(h) < 241 { + return false + } + if o := h.op(); o != opRequest && o != opReply { + return false + } + if h[1] != 0x01 || h[2] != 0x06 || h[3] != 0x00 { + return false + } + return bytes.Equal(h[236:240], magicCookie) && h[len(h)-1] == 0 +} + +func (h header) op() op { return op(h[0]) } +func (h header) setOp(o op) { h[0] = byte(o) } +func (h header) xidbytes() []byte { return h[4:8] } +func (h header) xid() xid { return xid(h[4])<<24 | xid(h[5])<<16 | xid(h[6])<<8 | xid(h[7]) } +func (h header) setBroadcast() { h[10], h[11] = 0x80, 0x00 } // flags top bit +func (h header) ciaddr() []byte { return h[12:16] } +func (h header) yiaddr() []byte { return h[16:20] } +func (h header) siaddr() []byte { return h[20:24] } +func (h header) giaddr() []byte { return h[24:28] } +func (h header) chaddr() []byte { return h[28:44] } +func (h header) sname() []byte { return h[44:108] } +func (h header) file() []byte { return h[108:236] } + +func (h header) options() (opts options, err error) { + i := headerBaseSize + for i < len(h) { + if h[i] == 0 { + i++ + continue + } + if h[i] == 255 { + break + } + if len(h) <= i+1 { + return nil, fmt.Errorf("option missing length") + } + optlen := int(h[i+1]) + if len(h) < i+2+optlen { + return nil, fmt.Errorf("option too long") + } + opts = append(opts, option{ + code: optionCode(h[i]), + body: h[i+2 : i+2+optlen], + }) + i += 2 + optlen + } + return opts, nil +} + +func (h header) setOptions(opts []option) { + i := headerBaseSize + for _, opt := range opts { + h[i] = byte(opt.code) + h[i+1] = byte(len(opt.body)) + copy(h[i+2:i+2+len(opt.body)], opt.body) + i += 2 + len(opt.body) + } + for ; i < len(h); i++ { + h[i] = 0 + } +} + +// headerBaseSize is the size of a DHCP packet, including the magic cookie. +// +// Note that a DHCP packet is required to have an 'end' option that takes +// up an extra byte, so the minimum DHCP packet size is headerBaseSize + 1. +const headerBaseSize = 240 + +type option struct { + code optionCode + body []byte +} + +type optionCode byte + +const ( + optSubnetMask optionCode = 1 + optDefaultGateway optionCode = 3 + optDomainNameServer optionCode = 6 + optReqIPAddr optionCode = 50 + optLeaseTime optionCode = 51 + optDHCPMsgType optionCode = 53 // dhcpMsgType + optDHCPServer optionCode = 54 + optParamReq optionCode = 55 +) + +func (code optionCode) len() int { + switch code { + case optSubnetMask, optDefaultGateway, optDomainNameServer, + optReqIPAddr, optLeaseTime, optDHCPServer: + return 4 + case optDHCPMsgType: + return 1 + case optParamReq: + return -1 // no fixed length + default: + return -1 + } +} + +func (code optionCode) String() string { + switch code { + case optSubnetMask: + return "option(subnet-mask)" + case optDefaultGateway: + return "option(default-gateway)" + case optDomainNameServer: + return "option(dns)" + case optReqIPAddr: + return "option(request-ip-address)" + case optLeaseTime: + return "option(least-time)" + case optDHCPMsgType: + return "option(message-type)" + case optDHCPServer: + return "option(server)" + case optParamReq: + return "option(parameter-request)" + default: + return fmt.Sprintf("option(%d)", code) + } +} + +type options []option + +func (opts options) dhcpMsgType() (dhcpMsgType, error) { + for _, opt := range opts { + if opt.code == optDHCPMsgType { + if len(opt.body) != 1 { + return 0, fmt.Errorf("%s: bad length: %d", optDHCPMsgType, len(opt.body)) + } + v := opt.body[0] + if v <= 0 || v >= 8 { + return 0, fmt.Errorf("%s: unknown value: %d", optDHCPMsgType, v) + } + return dhcpMsgType(v), nil + } + } + return 0, nil +} + +func (opts options) len() int { + l := 0 + for _, opt := range opts { + l += 1 + 1 + len(opt.body) // code + len + body + } + return l + 1 // extra byte for 'pad' option +} + +type op byte + +const ( + opRequest op = 0x01 + opReply op = 0x02 +) + +// dhcpMsgType is the DHCP Message Type from RFC 1533, section 9.4. +type dhcpMsgType byte + +const ( + dhcpDISCOVER dhcpMsgType = 1 + dhcpOFFER dhcpMsgType = 2 + dhcpREQUEST dhcpMsgType = 3 + dhcpDECLINE dhcpMsgType = 4 + dhcpACK dhcpMsgType = 5 + dhcpNAK dhcpMsgType = 6 + dhcpRELEASE dhcpMsgType = 7 +) diff --git a/pkg/dhcp/dhcp_test.go b/pkg/dhcp/dhcp_test.go new file mode 100644 index 000000000..d56b93997 --- /dev/null +++ b/pkg/dhcp/dhcp_test.go @@ -0,0 +1,113 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dhcp + +import ( + "context" + "strings" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" +) + +func TestDHCP(t *testing.T) { + const defaultMTU = 65536 + id, linkEP := channel.New(256, defaultMTU, "") + if testing.Verbose() { + id = sniffer.New(id) + } + + go func() { + for pkt := range linkEP.C { + v := make(buffer.View, len(pkt.Header)+len(pkt.Payload)) + copy(v, pkt.Header) + copy(v[len(pkt.Header):], pkt.Payload) + vv := v.ToVectorisedView([1]buffer.View{}) + linkEP.Inject(pkt.Proto, &vv) + } + }() + + s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}) + + const nicid tcpip.NICID = 1 + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal(err) + } + if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil { + t.Fatal(err) + } + if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil { + t.Fatal(err) + } + const serverAddr = tcpip.Address("\xc0\xa8\x03\x01") + if err := s.AddAddress(nicid, ipv4.ProtocolNumber, serverAddr); err != nil { + t.Fatal(err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: tcpip.Address(strings.Repeat("\x00", 4)), + Mask: tcpip.Address(strings.Repeat("\x00", 4)), + Gateway: "", + NIC: nicid, + }}) + + var clientAddrs = []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"} + + serverCfg := Config{ + ServerAddress: serverAddr, + SubnetMask: "\xff\xff\xff\x00", + Gateway: "\xc0\xa8\x03\xF0", + DomainNameServer: "\x08\x08\x08\x08", + LeaseLength: 24 * time.Hour, + } + serverCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := NewServer(serverCtx, s, clientAddrs, serverCfg) + if err != nil { + t.Fatal(err) + } + + const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52") + c0 := NewClient(s, nicid, clientLinkAddr0) + if err := c0.Request(context.Background(), ""); err != nil { + t.Fatal(err) + } + if got, want := c0.Address(), clientAddrs[0]; got != want { + t.Errorf("c.Addr()=%s, want=%s", got, want) + } + if err := c0.Request(context.Background(), ""); err != nil { + t.Fatal(err) + } + if got, want := c0.Address(), clientAddrs[0]; got != want { + t.Errorf("c.Addr()=%s, want=%s", got, want) + } + + const clientLinkAddr1 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x53") + c1 := NewClient(s, nicid, clientLinkAddr1) + if err := c1.Request(context.Background(), ""); err != nil { + t.Fatal(err) + } + if got, want := c1.Address(), clientAddrs[1]; got != want { + t.Errorf("c.Addr()=%s, want=%s", got, want) + } + + if err := c0.Request(context.Background(), ""); err != nil { + t.Fatal(err) + } + if got, want := c0.Address(), clientAddrs[0]; got != want { + t.Errorf("c.Addr()=%s, want=%s", got, want) + } + + if got, want := c0.Config(), serverCfg; got != want { + t.Errorf("client config:\n\t%#+v\nwant:\n\t%#+v", got, want) + } +} diff --git a/pkg/dhcp/server.go b/pkg/dhcp/server.go new file mode 100644 index 000000000..d132d90b4 --- /dev/null +++ b/pkg/dhcp/server.go @@ -0,0 +1,289 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dhcp + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Server is a DHCP server. +type Server struct { + stack *stack.Stack + broadcast tcpip.FullAddress + wq waiter.Queue + ep tcpip.Endpoint + addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure + cfg Config + cfgopts []option // cfg to send to client + + handlers []chan header + + mu sync.Mutex + leases map[tcpip.LinkAddress]serverLease +} + +// NewServer creates a new DHCP server and begins serving. +// The server continues serving until ctx is done. +func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) { + s := &Server{ + stack: stack, + addrs: addrs, + cfg: cfg, + cfgopts: cfg.encode(), + broadcast: tcpip.FullAddress{ + Addr: "\xff\xff\xff\xff", + Port: clientPort, + }, + + handlers: make([]chan header, 8), + leases: make(map[tcpip.LinkAddress]serverLease), + } + + var err *tcpip.Error + s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq) + if err != nil { + return nil, fmt.Errorf("dhcp: server endpoint: %v", err) + } + serverBroadcast := tcpip.FullAddress{ + Addr: "", + Port: serverPort, + } + if err := s.ep.Bind(serverBroadcast, nil); err != nil { + return nil, fmt.Errorf("dhcp: server bind: %v", err) + } + + for i := 0; i < len(s.handlers); i++ { + ch := make(chan header, 8) + s.handlers[i] = ch + go s.handler(ctx, ch) + } + + go s.expirer(ctx) + go s.reader(ctx) + return s, nil +} + +func (s *Server) expirer(ctx context.Context) { + t := time.NewTicker(1 * time.Minute) + defer t.Stop() + for { + select { + case <-t.C: + s.mu.Lock() + for linkAddr, lease := range s.leases { + if time.Since(lease.start) > s.cfg.LeaseLength { + lease.state = leaseExpired + s.leases[linkAddr] = lease + } + } + s.mu.Unlock() + case <-ctx.Done(): + return + } + } +} + +// reader listens for all incoming DHCP packets and fans them out to +// handling goroutines based on XID as session identifiers. +func (s *Server) reader(ctx context.Context) { + we, ch := waiter.NewChannelEntry(nil) + s.wq.EventRegister(&we, waiter.EventIn) + defer s.wq.EventUnregister(&we) + + for { + var addr tcpip.FullAddress + v, err := s.ep.Read(&addr) + if err == tcpip.ErrWouldBlock { + select { + case <-ch: + continue + case <-ctx.Done(): + return + } + } + + h := header(v) + if !h.isValid() || h.op() != opRequest { + continue + } + xid := h.xid() + + // Fan out the packet to a handler goroutine. + // + // Use a consistent handler for a given xid, so that + // packets from a particular client are processed + // in order. + ch := s.handlers[int(xid)%len(s.handlers)] + select { + case <-ctx.Done(): + return + case ch <- h: + default: + // drop the packet + } + } +} + +func (s *Server) handler(ctx context.Context, ch chan header) { + for { + select { + case h := <-ch: + if h == nil { + return + } + opts, err := h.options() + if err != nil { + continue + } + // TODO: Handle DHCPRELEASE and DHCPDECLINE. + msgtype, err := opts.dhcpMsgType() + if err != nil { + continue + } + switch msgtype { + case dhcpDISCOVER: + s.handleDiscover(h, opts) + case dhcpREQUEST: + s.handleRequest(h, opts) + } + case <-ctx.Done(): + return + } + } +} + +func (s *Server) handleDiscover(hreq header, opts options) { + linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6]) + xid := hreq.xid() + + s.mu.Lock() + lease := s.leases[linkAddr] + switch lease.state { + case leaseNew: + if len(s.leases) < len(s.addrs) { + // Find an unused address. + // TODO: avoid building this state on each request. + alloced := make(map[tcpip.Address]bool) + for _, lease := range s.leases { + alloced[lease.addr] = true + } + for _, addr := range s.addrs { + if !alloced[addr] { + lease = serverLease{ + start: time.Now(), + addr: addr, + xid: xid, + state: leaseOffer, + } + s.leases[linkAddr] = lease + break + } + } + } else { + // No more addresses, take an expired address. + for k, oldLease := range s.leases { + if oldLease.state == leaseExpired { + delete(s.leases, k) + lease = serverLease{ + start: time.Now(), + addr: lease.addr, + xid: xid, + state: leaseOffer, + } + s.leases[linkAddr] = lease + break + } + } + log.Printf("server has no more addresses") + s.mu.Unlock() + return + } + case leaseOffer, leaseAck, leaseExpired: + lease = serverLease{ + start: time.Now(), + addr: s.leases[linkAddr].addr, + xid: xid, + state: leaseOffer, + } + s.leases[linkAddr] = lease + } + s.mu.Unlock() + + // DHCPOFFER + opts = options{{optDHCPMsgType, []byte{byte(dhcpOFFER)}}} + opts = append(opts, s.cfgopts...) + h := make(header, headerBaseSize+opts.len()) + h.init() + h.setOp(opReply) + copy(h.xidbytes(), hreq.xidbytes()) + copy(h.yiaddr(), lease.addr) + copy(h.siaddr(), s.cfg.ServerAddress) + copy(h.chaddr(), hreq.chaddr()) + h.setOptions(opts) + s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast}) +} + +func (s *Server) handleRequest(hreq header, opts options) { + linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6]) + xid := hreq.xid() + + s.mu.Lock() + lease := s.leases[linkAddr] + switch lease.state { + case leaseOffer, leaseAck, leaseExpired: + lease = serverLease{ + start: time.Now(), + addr: s.leases[linkAddr].addr, + xid: xid, + state: leaseAck, + } + s.leases[linkAddr] = lease + } + s.mu.Unlock() + + if lease.state == leaseNew { + // TODO: NACK or accept request + return + } + + // DHCPACK + opts = []option{{optDHCPMsgType, []byte{byte(dhcpACK)}}} + opts = append(opts, s.cfgopts...) + h := make(header, headerBaseSize+opts.len()) + h.init() + h.setOp(opReply) + copy(h.xidbytes(), hreq.xidbytes()) + copy(h.yiaddr(), lease.addr) + copy(h.siaddr(), s.cfg.ServerAddress) + copy(h.chaddr(), hreq.chaddr()) + h.setOptions(opts) + s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast}) +} + +type leaseState int + +const ( + leaseNew leaseState = iota + leaseOffer + leaseAck + leaseExpired +) + +type serverLease struct { + start time.Time + addr tcpip.Address + xid xid + state leaseState +} |