summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@gmail.com>2018-08-25 06:16:34 -0700
committerShentubot <shentubot@google.com>2018-08-25 06:17:32 -0700
commitb17e80ef5a44e773e9032e7dbcb7438ff851ab7c (patch)
tree5c40dd5e44d70c51c2089ec10b51bd480fb8be50
parent106de2182d34197d76fb68863cd4a102ebac2dbb (diff)
Upstreaming DHCP changes from Fuchsia
PiperOrigin-RevId: 210221388 Change-Id: Ic82d592b8c4778855fa55ba913f6b9a10b2d511f
-rw-r--r--pkg/dhcp/BUILD3
-rw-r--r--pkg/dhcp/client.go285
-rw-r--r--pkg/dhcp/dhcp.go99
-rw-r--r--pkg/dhcp/dhcp_string.go115
-rw-r--r--pkg/dhcp/dhcp_test.go246
-rw-r--r--pkg/dhcp/server.go154
-rw-r--r--pkg/tcpip/stack/stack.go5
7 files changed, 687 insertions, 220 deletions
diff --git a/pkg/dhcp/BUILD b/pkg/dhcp/BUILD
index bd9f592b4..711a72c99 100644
--- a/pkg/dhcp/BUILD
+++ b/pkg/dhcp/BUILD
@@ -7,12 +7,14 @@ go_library(
srcs = [
"client.go",
"dhcp.go",
+ "dhcp_string.go",
"server.go",
],
importpath = "gvisor.googlesource.com/gvisor/pkg/dhcp",
deps = [
"//pkg/rand",
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/udp",
@@ -33,5 +35,6 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
],
)
diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go
index 8b5fc0452..909040e79 100644
--- a/pkg/dhcp/client.go
+++ b/pkg/dhcp/client.go
@@ -18,7 +18,6 @@ import (
"bytes"
"context"
"fmt"
- "log"
"sync"
"time"
@@ -32,9 +31,10 @@ import (
// Client is a DHCP client.
type Client struct {
- stack *stack.Stack
- nicid tcpip.NICID
- linkAddr tcpip.LinkAddress
+ stack *stack.Stack
+ nicid tcpip.NICID
+ linkAddr tcpip.LinkAddress
+ acquiredFunc func(old, new tcpip.Address, cfg Config)
mu sync.Mutex
addr tcpip.Address
@@ -46,29 +46,57 @@ type Client struct {
// NewClient creates a DHCP client.
//
// TODO: add s.LinkAddr(nicid) to *stack.Stack.
-func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client {
+func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress, acquiredFunc func(old, new tcpip.Address, cfg Config)) *Client {
return &Client{
- stack: s,
- nicid: nicid,
- linkAddr: linkAddr,
+ stack: s,
+ nicid: nicid,
+ linkAddr: linkAddr,
+ acquiredFunc: acquiredFunc,
}
}
-// Start starts the DHCP client.
+// Run starts the DHCP client.
// It will periodically search for an IP address using the Request method.
-func (c *Client) Start() {
- go func() {
- for {
- log.Print("DHCP request")
- ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- err := c.Request(ctx, "")
- cancel()
- if err == nil {
- break
- }
+func (c *Client) Run(ctx context.Context) {
+ go c.run(ctx)
+}
+
+func (c *Client) run(ctx context.Context) {
+ defer func() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.addr != "" {
+ c.stack.RemoveAddress(c.nicid, c.addr)
}
- log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength)
}()
+
+ var renewAddr tcpip.Address
+ for {
+ reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
+ cfg, err := c.Request(reqCtx, renewAddr)
+ cancel()
+ if err != nil {
+ select {
+ case <-time.After(1 * time.Second):
+ // loop and try again
+ case <-ctx.Done():
+ return
+ }
+ }
+
+ c.mu.Lock()
+ renewAddr = c.addr
+ c.mu.Unlock()
+
+ timer := time.NewTimer(cfg.LeaseLength)
+ select {
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ case <-timer.C:
+ // loop and make a renewal request
+ }
+ }
}
// Address reports the IP address acquired by the DHCP client.
@@ -85,56 +113,53 @@ func (c *Client) Config() Config {
return c.cfg
}
-// Shutdown relinquishes any lease and ends any outstanding renewal timers.
-func (c *Client) Shutdown() {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.addr != "" {
- c.stack.RemoveAddress(c.nicid, c.addr)
- }
- if c.cancelRenew != nil {
- c.cancelRenew()
- }
-}
-
// Request executes a DHCP request session.
//
// On success, it adds a new address to this client's TCPIP stack.
// If the server sets a lease limit a timer is set to automatically
// renew it.
-func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error {
+func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg Config, reterr error) {
+ if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil && err != tcpip.ErrDuplicateAddress {
+ return Config{}, fmt.Errorf("dhcp: %v", err)
+ }
+ if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil && err != tcpip.ErrDuplicateAddress {
+ return Config{}, fmt.Errorf("dhcp: %v", err)
+ }
+ defer c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
+ defer c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")
+
var wq waiter.Queue
ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- return fmt.Errorf("dhcp: outbound endpoint: %v", err)
+ return Config{}, fmt.Errorf("dhcp: outbound endpoint: %v", err)
}
- err = ep.Bind(tcpip.FullAddress{
- Addr: "\x00\x00\x00\x00",
- Port: clientPort,
- }, nil)
defer ep.Close()
- if err != nil {
- return fmt.Errorf("dhcp: connect failed: %v", err)
+ if err := ep.Bind(tcpip.FullAddress{
+ Addr: "\x00\x00\x00\x00",
+ Port: ClientPort,
+ NIC: c.nicid,
+ }, nil); err != nil {
+ return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- return fmt.Errorf("dhcp: inbound endpoint: %v", err)
+ return Config{}, fmt.Errorf("dhcp: inbound endpoint: %v", err)
}
- err = epin.Bind(tcpip.FullAddress{
- Addr: "\xff\xff\xff\xff",
- Port: clientPort,
- }, nil)
defer epin.Close()
- if err != nil {
- return fmt.Errorf("dhcp: connect failed: %v", err)
+ if err := epin.Bind(tcpip.FullAddress{
+ Addr: "\xff\xff\xff\xff",
+ Port: ClientPort,
+ NIC: c.nicid,
+ }, nil); err != nil {
+ return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
var xid [4]byte
rand.Read(xid[:])
// DHCPDISCOVERY
- options := options{
+ discOpts := options{
{optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
{optParamReq, []byte{
1, // request subnet mask
@@ -144,25 +169,34 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
}},
}
if requestedAddr != "" {
- options = append(options, option{optReqIPAddr, []byte(requestedAddr)})
+ discOpts = append(discOpts, option{optReqIPAddr, []byte(requestedAddr)})
}
- h := make(header, headerBaseSize+options.len())
+ var clientID []byte
+ if len(c.linkAddr) == 6 {
+ clientID = append(
+ []byte{1}, // RFC 1700: Hardware Type [Ethernet = 1]
+ c.linkAddr...,
+ )
+ discOpts = append(discOpts, option{optClientID, clientID})
+ }
+ h := make(header, headerBaseSize+discOpts.len()+1)
h.init()
h.setOp(opRequest)
copy(h.xidbytes(), xid[:])
h.setBroadcast()
copy(h.chaddr(), c.linkAddr)
- h.setOptions(options)
+ h.setOptions(discOpts)
serverAddr := &tcpip.FullAddress{
Addr: "\xff\xff\xff\xff",
- Port: serverPort,
+ Port: ServerPort,
+ NIC: c.nicid,
}
wopts := tcpip.WriteOptions{
To: serverAddr,
}
if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
- return fmt.Errorf("dhcp discovery write: %v", err)
+ return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -170,6 +204,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
defer wq.EventUnregister(&we)
// DHCPOFFER
+ var opts options
for {
var addr tcpip.FullAddress
v, _, err := epin.Read(&addr)
@@ -178,49 +213,84 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
case <-ch:
continue
case <-ctx.Done():
- return fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
+ return Config{}, fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
}
}
h = header(v)
- if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
- break
+ var valid bool
+ var e error
+ opts, valid, e = loadDHCPReply(h, dhcpOFFER, xid[:])
+ if !valid {
+ if e != nil {
+ // TODO: handle all the errors?
+ // TODO: report malformed server responses
+ }
+ continue
}
- }
- if _, err := h.options(); err != nil {
- return fmt.Errorf("dhcp offer: %v", err)
+ break
}
var ack bool
- var cfg Config
+ if err := cfg.decode(opts); err != nil {
+ return Config{}, fmt.Errorf("dhcp offer: %v", err)
+ }
// DHCPREQUEST
addr := tcpip.Address(h.yiaddr())
if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
if err != tcpip.ErrDuplicateAddress {
- return fmt.Errorf("adding address: %v", err)
+ return Config{}, fmt.Errorf("adding address: %v", err)
}
}
defer func() {
- if ack {
- c.mu.Lock()
- c.addr = addr
- c.cfg = cfg
- c.mu.Unlock()
- } else {
+ if !ack || reterr != nil {
c.stack.RemoveAddress(c.nicid, addr)
+ addr = ""
+ cfg = Config{Error: reterr}
+ }
+
+ c.mu.Lock()
+ oldAddr := c.addr
+ c.addr = addr
+ c.cfg = cfg
+ c.mu.Unlock()
+
+ // Clean up broadcast addresses before calling acquiredFunc
+ // so nothing else uses them by mistake.
+ //
+ // (The deferred RemoveAddress calls above silently error.)
+ c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
+ c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")
+
+ if c.acquiredFunc != nil {
+ c.acquiredFunc(oldAddr, addr, cfg)
+ }
+ if requestedAddr != "" && requestedAddr != addr {
+ c.stack.RemoveAddress(c.nicid, requestedAddr)
}
}()
+ h.init()
h.setOp(opRequest)
for i, b := 0, h.yiaddr(); i < len(b); i++ {
b[i] = 0
}
- h.setOptions([]option{
+ for i, b := 0, h.siaddr(); i < len(b); i++ {
+ b[i] = 0
+ }
+ for i, b := 0, h.giaddr(); i < len(b); i++ {
+ b[i] = 0
+ }
+ reqOpts := []option{
{optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
{optReqIPAddr, []byte(addr)},
- {optDHCPServer, h.siaddr()},
- })
+ {optDHCPServer, []byte(cfg.ServerAddress)},
+ }
+ if len(clientID) != 0 {
+ reqOpts = append(reqOpts, option{optClientID, clientID})
+ }
+ h.setOptions(reqOpts)
if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
- return fmt.Errorf("dhcp discovery write: %v", err)
+ return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
// DHCPACK
@@ -232,53 +302,46 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
case <-ch:
continue
case <-ctx.Done():
- return fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
+ return Config{}, fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
}
}
h = header(v)
- if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
- break
+ var valid bool
+ var e error
+ opts, valid, e = loadDHCPReply(h, dhcpACK, xid[:])
+ if !valid {
+ if e != nil {
+ // TODO: handle all the errors?
+ // TODO: report malformed server responses
+ }
+ if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid {
+ if msg := opts.message(); msg != "" {
+ return Config{}, fmt.Errorf("dhcp: NAK %q", msg)
+ }
+ return Config{}, fmt.Errorf("dhcp: NAK with no message")
+ }
+ continue
}
+ break
}
- opts, e := h.options()
- if e != nil {
- return fmt.Errorf("dhcp ack: %v", e)
- }
- if err := cfg.decode(opts); err != nil {
- return fmt.Errorf("dhcp ack bad options: %v", err)
+ ack = true
+ return cfg, nil
+}
+
+func loadDHCPReply(h header, typ dhcpMsgType, xid []byte) (opts options, valid bool, err error) {
+ if !h.isValid() || h.op() != opReply || !bytes.Equal(h.xidbytes(), xid[:]) {
+ return nil, false, nil
}
- msgtype, e := opts.dhcpMsgType()
- if e != nil {
- return fmt.Errorf("dhcp ack: %v", e)
+ opts, err = h.options()
+ if err != nil {
+ return nil, false, err
}
- ack = msgtype == dhcpACK
- if !ack {
- return fmt.Errorf("dhcp: request not acknowledged")
+ msgtype, err := opts.dhcpMsgType()
+ if err != nil {
+ return nil, false, err
}
- if cfg.LeaseLength != 0 {
- go c.renewAfter(cfg.LeaseLength)
+ if msgtype != typ {
+ return nil, false, nil
}
- return nil
-}
-
-func (c *Client) renewAfter(d time.Duration) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.cancelRenew != nil {
- c.cancelRenew()
- }
- ctx, cancel := context.WithCancel(context.Background())
- c.cancelRenew = cancel
- go func() {
- timer := time.NewTimer(d)
- defer timer.Stop()
- select {
- case <-ctx.Done():
- case <-timer.C:
- if err := c.Request(ctx, c.addr); err != nil {
- log.Printf("address renewal failed: %v", err)
- go c.renewAfter(1 * time.Minute)
- }
- }
- }()
+ return opts, true, nil
}
diff --git a/pkg/dhcp/dhcp.go b/pkg/dhcp/dhcp.go
index 18c318fc8..ceaba34c3 100644
--- a/pkg/dhcp/dhcp.go
+++ b/pkg/dhcp/dhcp.go
@@ -26,19 +26,21 @@ import (
// Config is standard DHCP configuration.
type Config struct {
- ServerAddress tcpip.Address // address of the server
- SubnetMask tcpip.AddressMask // client address subnet mask
- Gateway tcpip.Address // client default gateway
- DomainNameServer tcpip.Address // client domain name server
- LeaseLength time.Duration // length of the address lease
+ Error error
+ ServerAddress tcpip.Address // address of the server
+ SubnetMask tcpip.AddressMask // client address subnet mask
+ Gateway tcpip.Address // client default gateway
+ DNS []tcpip.Address // client DNS server addresses
+ LeaseLength time.Duration // length of the address lease
}
func (cfg *Config) decode(opts []option) error {
*cfg = Config{}
for _, opt := range opts {
b := opt.body
- if l := opt.code.len(); l != -1 && l != len(b) {
- return fmt.Errorf("%s bad length: %d", opt.code, len(b))
+ if !opt.code.lenValid(len(b)) {
+ // TODO: s/%v/%s/ when `go vet` is smarter.
+ return fmt.Errorf("%v: bad length: %d", opt.code, len(b))
}
switch opt.code {
case optLeaseTime:
@@ -51,7 +53,12 @@ func (cfg *Config) decode(opts []option) error {
case optDefaultGateway:
cfg.Gateway = tcpip.Address(b)
case optDomainNameServer:
- cfg.DomainNameServer = tcpip.Address(b)
+ for ; len(b) > 0; b = b[4:] {
+ if len(b) < 4 {
+ return fmt.Errorf("DNS bad length: %d", len(b))
+ }
+ cfg.DNS = append(cfg.DNS, tcpip.Address(b[:4]))
+ }
}
}
return nil
@@ -67,8 +74,12 @@ func (cfg Config) encode() (opts []option) {
if cfg.Gateway != "" {
opts = append(opts, option{optDefaultGateway, []byte(cfg.Gateway)})
}
- if cfg.DomainNameServer != "" {
- opts = append(opts, option{optDomainNameServer, []byte(cfg.DomainNameServer)})
+ if len(cfg.DNS) > 0 {
+ dns := make([]byte, 0, 4*len(cfg.DNS))
+ for _, addr := range cfg.DNS {
+ dns = append(dns, addr...)
+ }
+ opts = append(opts, option{optDomainNameServer, dns})
}
if l := cfg.LeaseLength / time.Second; l != 0 {
v := make([]byte, 4)
@@ -82,8 +93,10 @@ func (cfg Config) encode() (opts []option) {
}
const (
- serverPort = 67
- clientPort = 68
+ // ServerPort is the well-known UDP port number for a DHCP server.
+ ServerPort = 67
+ // ClientPort is the well-known UDP port number for a DHCP client.
+ ClientPort = 68
)
var magicCookie = []byte{99, 130, 83, 99} // RFC 1497
@@ -107,10 +120,10 @@ func (h header) isValid() bool {
if o := h.op(); o != opRequest && o != opReply {
return false
}
- if h[1] != 0x01 || h[2] != 0x06 || h[3] != 0x00 {
+ if h[1] != 0x01 || h[2] != 0x06 {
return false
}
- return bytes.Equal(h[236:240], magicCookie) && h[len(h)-1] == 0
+ return bytes.Equal(h[236:240], magicCookie)
}
func (h header) op() op { return op(h[0]) }
@@ -141,7 +154,7 @@ func (h header) options() (opts options, err error) {
}
optlen := int(h[i+1])
if len(h) < i+2+optlen {
- return nil, fmt.Errorf("option too long")
+ return nil, fmt.Errorf("option %v too long i=%d, optlen=%d", optionCode(h[i]), i, optlen)
}
opts = append(opts, option{
code: optionCode(h[i]),
@@ -160,6 +173,8 @@ func (h header) setOptions(opts []option) {
copy(h[i+2:i+2+len(opt.body)], opt.body)
i += 2 + len(opt.body)
}
+ h[i] = 255 // End option
+ i++
for ; i < len(h); i++ {
h[i] = 0
}
@@ -182,47 +197,31 @@ const (
optSubnetMask optionCode = 1
optDefaultGateway optionCode = 3
optDomainNameServer optionCode = 6
+ optDomainName optionCode = 15
optReqIPAddr optionCode = 50
optLeaseTime optionCode = 51
optDHCPMsgType optionCode = 53 // dhcpMsgType
optDHCPServer optionCode = 54
optParamReq optionCode = 55
+ optMessage optionCode = 56
+ optClientID optionCode = 61
)
-func (code optionCode) len() int {
+func (code optionCode) lenValid(l int) bool {
switch code {
- case optSubnetMask, optDefaultGateway, optDomainNameServer,
+ case optSubnetMask, optDefaultGateway,
optReqIPAddr, optLeaseTime, optDHCPServer:
- return 4
+ return l == 4
case optDHCPMsgType:
- return 1
- case optParamReq:
- return -1 // no fixed length
- default:
- return -1
- }
-}
-
-func (code optionCode) String() string {
- switch code {
- case optSubnetMask:
- return "option(subnet-mask)"
- case optDefaultGateway:
- return "option(default-gateway)"
+ return l == 1
case optDomainNameServer:
- return "option(dns)"
- case optReqIPAddr:
- return "option(request-ip-address)"
- case optLeaseTime:
- return "option(least-time)"
- case optDHCPMsgType:
- return "option(message-type)"
- case optDHCPServer:
- return "option(server)"
+ return l%4 == 0
+ case optMessage, optDomainName, optClientID:
+ return l >= 1
case optParamReq:
- return "option(parameter-request)"
+ return true // no fixed length
default:
- return fmt.Sprintf("option(%d)", code)
+ return true // unknown option, assume ok
}
}
@@ -232,11 +231,12 @@ func (opts options) dhcpMsgType() (dhcpMsgType, error) {
for _, opt := range opts {
if opt.code == optDHCPMsgType {
if len(opt.body) != 1 {
- return 0, fmt.Errorf("%s: bad length: %d", optDHCPMsgType, len(opt.body))
+ // TODO: s/%v/%s/ when `go vet` is smarter.
+ return 0, fmt.Errorf("%v: bad length: %d", opt.code, len(opt.body))
}
v := opt.body[0]
if v <= 0 || v >= 8 {
- return 0, fmt.Errorf("%s: unknown value: %d", optDHCPMsgType, v)
+ return 0, fmt.Errorf("DHCP bad length: %d", len(opt.body))
}
return dhcpMsgType(v), nil
}
@@ -244,6 +244,15 @@ func (opts options) dhcpMsgType() (dhcpMsgType, error) {
return 0, nil
}
+func (opts options) message() string {
+ for _, opt := range opts {
+ if opt.code == optMessage {
+ return string(opt.body)
+ }
+ }
+ return ""
+}
+
func (opts options) len() int {
l := 0
for _, opt := range opts {
diff --git a/pkg/dhcp/dhcp_string.go b/pkg/dhcp/dhcp_string.go
new file mode 100644
index 000000000..7cabed29e
--- /dev/null
+++ b/pkg/dhcp/dhcp_string.go
@@ -0,0 +1,115 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dhcp
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+func (h header) String() string {
+ opts, err := h.options()
+ var msgtype dhcpMsgType
+ if err == nil {
+ msgtype, err = opts.dhcpMsgType()
+ }
+ if !h.isValid() || err != nil {
+ return fmt.Sprintf("DHCP invalid, %v %v h[1:4]=%x cookie=%x len=%d (%v)", h.op(), h.xid(), []byte(h[1:4]), []byte(h[236:240]), len(h), err)
+ }
+ buf := new(bytes.Buffer)
+ fmt.Fprintf(buf, "%v %v len=%d\n", msgtype, h.xid(), len(h))
+ fmt.Fprintf(buf, "\tciaddr:%v yiaddr:%v siaddr:%v giaddr:%v\n",
+ tcpip.Address(h.ciaddr()),
+ tcpip.Address(h.yiaddr()),
+ tcpip.Address(h.siaddr()),
+ tcpip.Address(h.giaddr()))
+ fmt.Fprintf(buf, "\tchaddr:%x", h.chaddr())
+ for _, opt := range opts {
+ fmt.Fprintf(buf, "\n\t%v", opt)
+ }
+ return buf.String()
+}
+
+func (opt option) String() string {
+ buf := new(bytes.Buffer)
+ fmt.Fprintf(buf, "%v: ", opt.code)
+ fmt.Fprintf(buf, "%x", opt.body)
+ return buf.String()
+}
+
+func (code optionCode) String() string {
+ switch code {
+ case optSubnetMask:
+ return "option(subnet-mask)"
+ case optDefaultGateway:
+ return "option(default-gateway)"
+ case optDomainNameServer:
+ return "option(dns)"
+ case optDomainName:
+ return "option(domain-name)"
+ case optReqIPAddr:
+ return "option(request-ip-address)"
+ case optLeaseTime:
+ return "option(lease-time)"
+ case optDHCPMsgType:
+ return "option(message-type)"
+ case optDHCPServer:
+ return "option(server)"
+ case optParamReq:
+ return "option(parameter-request)"
+ case optMessage:
+ return "option(message)"
+ case optClientID:
+ return "option(client-id)"
+ default:
+ return fmt.Sprintf("option(%d)", code)
+ }
+}
+
+func (o op) String() string {
+ switch o {
+ case opRequest:
+ return "op(request)"
+ case opReply:
+ return "op(reply)"
+ }
+ return fmt.Sprintf("op(UNKNOWN:%d)", int(o))
+}
+
+func (t dhcpMsgType) String() string {
+ switch t {
+ case dhcpDISCOVER:
+ return "DHCPDISCOVER"
+ case dhcpOFFER:
+ return "DHCPOFFER"
+ case dhcpREQUEST:
+ return "DHCPREQUEST"
+ case dhcpDECLINE:
+ return "DHCPDECLINE"
+ case dhcpACK:
+ return "DHCPACK"
+ case dhcpNAK:
+ return "DHCPNAK"
+ case dhcpRELEASE:
+ return "DHCPRELEASE"
+ }
+ return fmt.Sprintf("DHCP(%d)", int(t))
+}
+
+func (v xid) String() string {
+ return fmt.Sprintf("xid:%x", uint32(v))
+}
diff --git a/pkg/dhcp/dhcp_test.go b/pkg/dhcp/dhcp_test.go
index 731ed61a5..67814683a 100644
--- a/pkg/dhcp/dhcp_test.go
+++ b/pkg/dhcp/dhcp_test.go
@@ -27,9 +27,13 @@ import (
"gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
)
-func TestDHCP(t *testing.T) {
+const nicid = tcpip.NICID(1)
+const serverAddr = tcpip.Address("\xc0\xa8\x03\x01")
+
+func createStack(t *testing.T) *stack.Stack {
const defaultMTU = 65536
id, linkEP := channel.New(256, defaultMTU, "")
if testing.Verbose() {
@@ -48,17 +52,9 @@ func TestDHCP(t *testing.T) {
s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
- const nicid tcpip.NICID = 1
if err := s.CreateNIC(nicid, id); err != nil {
t.Fatal(err)
}
- if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil {
- t.Fatal(err)
- }
- if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil {
- t.Fatal(err)
- }
- const serverAddr = tcpip.Address("\xc0\xa8\x03\x01")
if err := s.AddAddress(nicid, ipv4.ProtocolNumber, serverAddr); err != nil {
t.Fatal(err)
}
@@ -70,31 +66,38 @@ func TestDHCP(t *testing.T) {
NIC: nicid,
}})
- var clientAddrs = []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"}
+ return s
+}
+
+func TestDHCP(t *testing.T) {
+ s := createStack(t)
+ clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"}
serverCfg := Config{
- ServerAddress: serverAddr,
- SubnetMask: "\xff\xff\xff\x00",
- Gateway: "\xc0\xa8\x03\xF0",
- DomainNameServer: "\x08\x08\x08\x08",
- LeaseLength: 24 * time.Hour,
+ ServerAddress: serverAddr,
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{
+ "\x08\x08\x08\x08", "\x08\x08\x04\x04",
+ },
+ LeaseLength: 24 * time.Hour,
}
serverCtx, cancel := context.WithCancel(context.Background())
defer cancel()
- _, err := NewServer(serverCtx, s, clientAddrs, serverCfg)
+ _, err := newEPConnServer(serverCtx, s, clientAddrs, serverCfg)
if err != nil {
t.Fatal(err)
}
const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
- c0 := NewClient(s, nicid, clientLinkAddr0)
- if err := c0.Request(context.Background(), ""); err != nil {
+ c0 := NewClient(s, nicid, clientLinkAddr0, nil)
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if err := c0.Request(context.Background(), ""); err != nil {
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
@@ -102,22 +105,219 @@ func TestDHCP(t *testing.T) {
}
const clientLinkAddr1 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x53")
- c1 := NewClient(s, nicid, clientLinkAddr1)
- if err := c1.Request(context.Background(), ""); err != nil {
+ c1 := NewClient(s, nicid, clientLinkAddr1, nil)
+ if _, err := c1.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c1.Address(), clientAddrs[1]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if err := c0.Request(context.Background(), ""); err != nil {
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if got, want := c0.Config(), serverCfg; got != want {
+ if got, want := c0.Config(), serverCfg; !equalConfig(got, want) {
t.Errorf("client config:\n\t%#+v\nwant:\n\t%#+v", got, want)
}
}
+
+func equalConfig(c0, c1 Config) bool {
+ if c0.Error != c1.Error || c0.ServerAddress != c1.ServerAddress || c0.SubnetMask != c1.SubnetMask || c0.Gateway != c1.Gateway || c0.LeaseLength != c1.LeaseLength {
+ return false
+ }
+ if len(c0.DNS) != len(c1.DNS) {
+ return false
+ }
+ for i := 0; i < len(c0.DNS); i++ {
+ if c0.DNS[i] != c1.DNS[i] {
+ return false
+ }
+ }
+ return true
+}
+
+func TestRenew(t *testing.T) {
+ s := createStack(t)
+ clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02"}
+
+ serverCfg := Config{
+ ServerAddress: serverAddr,
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 1 * time.Second,
+ }
+ serverCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, err := newEPConnServer(serverCtx, s, clientAddrs, serverCfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ count := 0
+ var curAddr tcpip.Address
+ addrCh := make(chan tcpip.Address)
+ acquiredFunc := func(oldAddr, newAddr tcpip.Address, cfg Config) {
+ if err := cfg.Error; err != nil {
+ t.Fatalf("acquisition %d failed: %v", count, err)
+ }
+ if oldAddr != curAddr {
+ t.Fatalf("aquisition %d: curAddr=%v, oldAddr=%v", count, curAddr, oldAddr)
+ }
+ if cfg.LeaseLength != time.Second {
+ t.Fatalf("aquisition %d: lease length: %v, want %v", count, cfg.LeaseLength, time.Second)
+ }
+ count++
+ curAddr = newAddr
+ addrCh <- newAddr
+ }
+
+ clientCtx, cancel := context.WithCancel(context.Background())
+ const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
+ c := NewClient(s, nicid, clientLinkAddr0, acquiredFunc)
+ c.Run(clientCtx)
+
+ var addr tcpip.Address
+ select {
+ case addr = <-addrCh:
+ t.Logf("got first address: %v", addr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout acquiring initial address")
+ }
+
+ select {
+ case newAddr := <-addrCh:
+ t.Logf("got renewal: %v", newAddr)
+ if newAddr != addr {
+ t.Fatalf("renewal address is %v, want %v", newAddr, addr)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for address renewal")
+ }
+
+ cancel()
+}
+
+// Regression test for https://fuchsia.atlassian.net/browse/NET-17
+func TestNoNullTerminator(t *testing.T) {
+ v := "\x02\x01\x06\x00" +
+ "\xc8\x37\xbe\x73\x00\x00\x80\x00\x00\x00\x00\x00\xc0\xa8\x2b\x92" +
+ "\xc0\xa8\x2b\x01\x00\x00\x00\x00\x00\x0f\x60\x0a\x23\x93\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x63\x82\x53\x63\x35\x01\x02\x36" +
+ "\x04\xc0\xa8\x2b\x01\x33\x04\x00\x00\x0e\x10\x3a\x04\x00\x00\x07" +
+ "\x08\x3b\x04\x00\x00\x0c\x4e\x01\x04\xff\xff\xff\x00\x1c\x04\xc0" +
+ "\xa8\x2b\xff\x03\x04\xc0\xa8\x2b\x01\x06\x04\xc0\xa8\x2b\x01\x2b" +
+ "\x0f\x41\x4e\x44\x52\x4f\x49\x44\x5f\x4d\x45\x54\x45\x52\x45\x44" +
+ "\xff"
+ h := header(v)
+ if !h.isValid() {
+ t.Error("failed to decode header")
+ }
+
+ if got, want := h.op(), opReply; got != want {
+ t.Errorf("h.op()=%v, want=%v", got, want)
+ }
+
+ if _, err := h.options(); err != nil {
+ t.Errorf("bad options: %v", err)
+ }
+}
+
+func teeConn(c conn) (conn, conn) {
+ dup1 := &dupConn{
+ c: c,
+ dup: make(chan connMsg, 8),
+ }
+ dup2 := &chConn{
+ c: c,
+ ch: dup1.dup,
+ }
+ return dup1, dup2
+}
+
+type connMsg struct {
+ buf buffer.View
+ addr tcpip.FullAddress
+ err error
+}
+
+type dupConn struct {
+ c conn
+ dup chan connMsg
+}
+
+func (c *dupConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ v, addr, err := c.c.Read()
+ c.dup <- connMsg{v, addr, err}
+ return v, addr, err
+}
+func (c *dupConn) Write(b []byte, addr *tcpip.FullAddress) error { return c.c.Write(b, addr) }
+
+type chConn struct {
+ ch chan connMsg
+ c conn
+}
+
+func (c *chConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ msg := <-c.ch
+ return msg.buf, msg.addr, msg.err
+}
+func (c *chConn) Write(b []byte, addr *tcpip.FullAddress) error { return c.c.Write(b, addr) }
+
+func TestTwoServers(t *testing.T) {
+ s := createStack(t)
+
+ wq := new(waiter.Queue)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("dhcp: server endpoint: %v", err)
+ }
+ if err = ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil {
+ t.Fatalf("dhcp: server bind: %v", err)
+ }
+
+ serverCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ c1, c2 := teeConn(newEPConn(serverCtx, wq, ep))
+
+ if _, err := NewServer(serverCtx, c1, []tcpip.Address{"\xc0\xa8\x03\x02"}, Config{
+ ServerAddress: "\xc0\xa8\x03\x01",
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 30 * time.Minute,
+ }); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := NewServer(serverCtx, c2, []tcpip.Address{"\xc0\xa8\x04\x02"}, Config{
+ ServerAddress: "\xc0\xa8\x04\x01",
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 30 * time.Minute,
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
+ c := NewClient(s, nicid, clientLinkAddr0, nil)
+ if _, err := c.Request(context.Background(), ""); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/pkg/dhcp/server.go b/pkg/dhcp/server.go
index 0beac7782..003e272b2 100644
--- a/pkg/dhcp/server.go
+++ b/pkg/dhcp/server.go
@@ -17,11 +17,13 @@ package dhcp
import (
"context"
"fmt"
+ "io"
"log"
"sync"
"time"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
@@ -30,10 +32,8 @@ import (
// Server is a DHCP server.
type Server struct {
- stack *stack.Stack
+ conn conn
broadcast tcpip.FullAddress
- wq waiter.Queue
- ep tcpip.Endpoint
addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure
cfg Config
cfgopts []option // cfg to send to client
@@ -44,36 +44,96 @@ type Server struct {
leases map[tcpip.LinkAddress]serverLease
}
+// conn is a blocking read/write network endpoint.
+type conn interface {
+ Read() (buffer.View, tcpip.FullAddress, error)
+ Write([]byte, *tcpip.FullAddress) error
+}
+
+type epConn struct {
+ ctx context.Context
+ wq *waiter.Queue
+ ep tcpip.Endpoint
+ we waiter.Entry
+ inCh chan struct{}
+}
+
+func newEPConn(ctx context.Context, wq *waiter.Queue, ep tcpip.Endpoint) *epConn {
+ c := &epConn{
+ ctx: ctx,
+ wq: wq,
+ ep: ep,
+ }
+ c.we, c.inCh = waiter.NewChannelEntry(nil)
+ wq.EventRegister(&c.we, waiter.EventIn)
+
+ go func() {
+ <-ctx.Done()
+ wq.EventUnregister(&c.we)
+ }()
+
+ return c
+}
+
+func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ for {
+ var addr tcpip.FullAddress
+ v, _, err := c.ep.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ select {
+ case <-c.inCh:
+ continue
+ case <-c.ctx.Done():
+ return nil, tcpip.FullAddress{}, io.EOF
+ }
+ }
+ if err != nil {
+ return v, addr, fmt.Errorf("read: %v", err)
+ }
+ return v, addr, nil
+ }
+}
+
+func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
+ if _, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr}); err != nil {
+ return fmt.Errorf("write: %v", err)
+ }
+ return nil
+}
+
+func newEPConnServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
+ wq := new(waiter.Queue)
+ ep, err := stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil {
+ return nil, fmt.Errorf("dhcp: server bind: %v", err)
+ }
+ c := newEPConn(ctx, wq, ep)
+ return NewServer(ctx, c, addrs, cfg)
+}
+
// NewServer creates a new DHCP server and begins serving.
// The server continues serving until ctx is done.
-func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
+func NewServer(ctx context.Context, c conn, addrs []tcpip.Address, cfg Config) (*Server, error) {
+ if cfg.ServerAddress == "" {
+ return nil, fmt.Errorf("dhcp: server requires explicit server address")
+ }
s := &Server{
- stack: stack,
+ conn: c,
addrs: addrs,
cfg: cfg,
cfgopts: cfg.encode(),
broadcast: tcpip.FullAddress{
Addr: "\xff\xff\xff\xff",
- Port: clientPort,
+ Port: ClientPort,
},
handlers: make([]chan header, 8),
leases: make(map[tcpip.LinkAddress]serverLease),
}
- var err *tcpip.Error
- s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq)
- if err != nil {
- return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
- }
- serverBroadcast := tcpip.FullAddress{
- Addr: "",
- Port: serverPort,
- }
- if err := s.ep.Bind(serverBroadcast, nil); err != nil {
- return nil, fmt.Errorf("dhcp: server bind: %v", err)
- }
-
for i := 0; i < len(s.handlers); i++ {
ch := make(chan header, 8)
s.handlers[i] = ch
@@ -108,20 +168,10 @@ func (s *Server) expirer(ctx context.Context) {
// reader listens for all incoming DHCP packets and fans them out to
// handling goroutines based on XID as session identifiers.
func (s *Server) reader(ctx context.Context) {
- we, ch := waiter.NewChannelEntry(nil)
- s.wq.EventRegister(&we, waiter.EventIn)
- defer s.wq.EventUnregister(&we)
-
for {
- var addr tcpip.FullAddress
- v, _, err := s.ep.Read(&addr)
- if err == tcpip.ErrWouldBlock {
- select {
- case <-ch:
- continue
- case <-ctx.Done():
- return
- }
+ v, _, err := s.conn.Read()
+ if err != nil {
+ return
}
h := header(v)
@@ -234,21 +284,50 @@ func (s *Server) handleDiscover(hreq header, opts options) {
// DHCPOFFER
opts = options{{optDHCPMsgType, []byte{byte(dhcpOFFER)}}}
opts = append(opts, s.cfgopts...)
- h := make(header, headerBaseSize+opts.len())
+ h := make(header, headerBaseSize+opts.len()+1)
h.init()
h.setOp(opReply)
copy(h.xidbytes(), hreq.xidbytes())
copy(h.yiaddr(), lease.addr)
- copy(h.siaddr(), s.cfg.ServerAddress)
copy(h.chaddr(), hreq.chaddr())
h.setOptions(opts)
- s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast})
+ s.conn.Write([]byte(h), &s.broadcast)
+}
+
+func (s *Server) nack(hreq header) {
+ // DHCPNACK
+ opts := options([]option{
+ {optDHCPMsgType, []byte{byte(dhcpNAK)}},
+ {optDHCPServer, []byte(s.cfg.ServerAddress)},
+ })
+ h := make(header, headerBaseSize+opts.len()+1)
+ h.init()
+ h.setOp(opReply)
+ copy(h.xidbytes(), hreq.xidbytes())
+ copy(h.chaddr(), hreq.chaddr())
+ h.setOptions(opts)
+ s.conn.Write([]byte(h), &s.broadcast)
}
func (s *Server) handleRequest(hreq header, opts options) {
linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6])
xid := hreq.xid()
+ reqopts, err := hreq.options()
+ if err != nil {
+ s.nack(hreq)
+ return
+ }
+ var reqcfg Config
+ if err := reqcfg.decode(reqopts); err != nil {
+ s.nack(hreq)
+ return
+ }
+ if reqcfg.ServerAddress != s.cfg.ServerAddress {
+ // This request is for a different DHCP server. Ignore it.
+ return
+ }
+
s.mu.Lock()
lease := s.leases[linkAddr]
switch lease.state {
@@ -271,15 +350,14 @@ func (s *Server) handleRequest(hreq header, opts options) {
// DHCPACK
opts = []option{{optDHCPMsgType, []byte{byte(dhcpACK)}}}
opts = append(opts, s.cfgopts...)
- h := make(header, headerBaseSize+opts.len())
+ h := make(header, headerBaseSize+opts.len()+1)
h.init()
h.setOp(opReply)
copy(h.xidbytes(), hreq.xidbytes())
copy(h.yiaddr(), lease.addr)
- copy(h.siaddr(), s.cfg.ServerAddress)
copy(h.chaddr(), hreq.chaddr())
h.setOptions(opts)
- s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast})
+ s.conn.Write([]byte(h), &s.broadcast)
}
type leaseState int
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index cc5427cf9..2c8c4aa31 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -764,9 +764,8 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
s.linkAddrCache.add(fullAddr, linkAddr)
- // TODO: provide a way for a
- // transport endpoint to receive a signal that AddLinkAddress
- // for a particular address has been called.
+ // TODO: provide a way for a transport endpoint to receive a signal
+ // that AddLinkAddress for a particular address has been called.
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.