diff options
Diffstat (limited to 'pkg/dhcp')
-rw-r--r-- | pkg/dhcp/BUILD | 3 | ||||
-rw-r--r-- | pkg/dhcp/client.go | 285 | ||||
-rw-r--r-- | pkg/dhcp/dhcp.go | 99 | ||||
-rw-r--r-- | pkg/dhcp/dhcp_string.go | 115 | ||||
-rw-r--r-- | pkg/dhcp/dhcp_test.go | 246 | ||||
-rw-r--r-- | pkg/dhcp/server.go | 154 |
6 files changed, 685 insertions, 217 deletions
diff --git a/pkg/dhcp/BUILD b/pkg/dhcp/BUILD index bd9f592b4..711a72c99 100644 --- a/pkg/dhcp/BUILD +++ b/pkg/dhcp/BUILD @@ -7,12 +7,14 @@ go_library( srcs = [ "client.go", "dhcp.go", + "dhcp_string.go", "server.go", ], importpath = "gvisor.googlesource.com/gvisor/pkg/dhcp", deps = [ "//pkg/rand", "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/transport/udp", @@ -33,5 +35,6 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/transport/udp", + "//pkg/waiter", ], ) diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go index 8b5fc0452..909040e79 100644 --- a/pkg/dhcp/client.go +++ b/pkg/dhcp/client.go @@ -18,7 +18,6 @@ import ( "bytes" "context" "fmt" - "log" "sync" "time" @@ -32,9 +31,10 @@ import ( // Client is a DHCP client. type Client struct { - stack *stack.Stack - nicid tcpip.NICID - linkAddr tcpip.LinkAddress + stack *stack.Stack + nicid tcpip.NICID + linkAddr tcpip.LinkAddress + acquiredFunc func(old, new tcpip.Address, cfg Config) mu sync.Mutex addr tcpip.Address @@ -46,29 +46,57 @@ type Client struct { // NewClient creates a DHCP client. // // TODO: add s.LinkAddr(nicid) to *stack.Stack. -func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client { +func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress, acquiredFunc func(old, new tcpip.Address, cfg Config)) *Client { return &Client{ - stack: s, - nicid: nicid, - linkAddr: linkAddr, + stack: s, + nicid: nicid, + linkAddr: linkAddr, + acquiredFunc: acquiredFunc, } } -// Start starts the DHCP client. +// Run starts the DHCP client. // It will periodically search for an IP address using the Request method. -func (c *Client) Start() { - go func() { - for { - log.Print("DHCP request") - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - err := c.Request(ctx, "") - cancel() - if err == nil { - break - } +func (c *Client) Run(ctx context.Context) { + go c.run(ctx) +} + +func (c *Client) run(ctx context.Context) { + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + if c.addr != "" { + c.stack.RemoveAddress(c.nicid, c.addr) } - log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength) }() + + var renewAddr tcpip.Address + for { + reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + cfg, err := c.Request(reqCtx, renewAddr) + cancel() + if err != nil { + select { + case <-time.After(1 * time.Second): + // loop and try again + case <-ctx.Done(): + return + } + } + + c.mu.Lock() + renewAddr = c.addr + c.mu.Unlock() + + timer := time.NewTimer(cfg.LeaseLength) + select { + case <-ctx.Done(): + timer.Stop() + return + case <-timer.C: + // loop and make a renewal request + } + } } // Address reports the IP address acquired by the DHCP client. @@ -85,56 +113,53 @@ func (c *Client) Config() Config { return c.cfg } -// Shutdown relinquishes any lease and ends any outstanding renewal timers. -func (c *Client) Shutdown() { - c.mu.Lock() - defer c.mu.Unlock() - if c.addr != "" { - c.stack.RemoveAddress(c.nicid, c.addr) - } - if c.cancelRenew != nil { - c.cancelRenew() - } -} - // Request executes a DHCP request session. // // On success, it adds a new address to this client's TCPIP stack. // If the server sets a lease limit a timer is set to automatically // renew it. -func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error { +func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg Config, reterr error) { + if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil && err != tcpip.ErrDuplicateAddress { + return Config{}, fmt.Errorf("dhcp: %v", err) + } + if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil && err != tcpip.ErrDuplicateAddress { + return Config{}, fmt.Errorf("dhcp: %v", err) + } + defer c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff") + defer c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00") + var wq waiter.Queue ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - return fmt.Errorf("dhcp: outbound endpoint: %v", err) + return Config{}, fmt.Errorf("dhcp: outbound endpoint: %v", err) } - err = ep.Bind(tcpip.FullAddress{ - Addr: "\x00\x00\x00\x00", - Port: clientPort, - }, nil) defer ep.Close() - if err != nil { - return fmt.Errorf("dhcp: connect failed: %v", err) + if err := ep.Bind(tcpip.FullAddress{ + Addr: "\x00\x00\x00\x00", + Port: ClientPort, + NIC: c.nicid, + }, nil); err != nil { + return Config{}, fmt.Errorf("dhcp: connect failed: %v", err) } epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - return fmt.Errorf("dhcp: inbound endpoint: %v", err) + return Config{}, fmt.Errorf("dhcp: inbound endpoint: %v", err) } - err = epin.Bind(tcpip.FullAddress{ - Addr: "\xff\xff\xff\xff", - Port: clientPort, - }, nil) defer epin.Close() - if err != nil { - return fmt.Errorf("dhcp: connect failed: %v", err) + if err := epin.Bind(tcpip.FullAddress{ + Addr: "\xff\xff\xff\xff", + Port: ClientPort, + NIC: c.nicid, + }, nil); err != nil { + return Config{}, fmt.Errorf("dhcp: connect failed: %v", err) } var xid [4]byte rand.Read(xid[:]) // DHCPDISCOVERY - options := options{ + discOpts := options{ {optDHCPMsgType, []byte{byte(dhcpDISCOVER)}}, {optParamReq, []byte{ 1, // request subnet mask @@ -144,25 +169,34 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error }}, } if requestedAddr != "" { - options = append(options, option{optReqIPAddr, []byte(requestedAddr)}) + discOpts = append(discOpts, option{optReqIPAddr, []byte(requestedAddr)}) } - h := make(header, headerBaseSize+options.len()) + var clientID []byte + if len(c.linkAddr) == 6 { + clientID = append( + []byte{1}, // RFC 1700: Hardware Type [Ethernet = 1] + c.linkAddr..., + ) + discOpts = append(discOpts, option{optClientID, clientID}) + } + h := make(header, headerBaseSize+discOpts.len()+1) h.init() h.setOp(opRequest) copy(h.xidbytes(), xid[:]) h.setBroadcast() copy(h.chaddr(), c.linkAddr) - h.setOptions(options) + h.setOptions(discOpts) serverAddr := &tcpip.FullAddress{ Addr: "\xff\xff\xff\xff", - Port: serverPort, + Port: ServerPort, + NIC: c.nicid, } wopts := tcpip.WriteOptions{ To: serverAddr, } if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil { - return fmt.Errorf("dhcp discovery write: %v", err) + return Config{}, fmt.Errorf("dhcp discovery write: %v", err) } we, ch := waiter.NewChannelEntry(nil) @@ -170,6 +204,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error defer wq.EventUnregister(&we) // DHCPOFFER + var opts options for { var addr tcpip.FullAddress v, _, err := epin.Read(&addr) @@ -178,49 +213,84 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error case <-ch: continue case <-ctx.Done(): - return fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted) + return Config{}, fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted) } } h = header(v) - if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) { - break + var valid bool + var e error + opts, valid, e = loadDHCPReply(h, dhcpOFFER, xid[:]) + if !valid { + if e != nil { + // TODO: handle all the errors? + // TODO: report malformed server responses + } + continue } - } - if _, err := h.options(); err != nil { - return fmt.Errorf("dhcp offer: %v", err) + break } var ack bool - var cfg Config + if err := cfg.decode(opts); err != nil { + return Config{}, fmt.Errorf("dhcp offer: %v", err) + } // DHCPREQUEST addr := tcpip.Address(h.yiaddr()) if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil { if err != tcpip.ErrDuplicateAddress { - return fmt.Errorf("adding address: %v", err) + return Config{}, fmt.Errorf("adding address: %v", err) } } defer func() { - if ack { - c.mu.Lock() - c.addr = addr - c.cfg = cfg - c.mu.Unlock() - } else { + if !ack || reterr != nil { c.stack.RemoveAddress(c.nicid, addr) + addr = "" + cfg = Config{Error: reterr} + } + + c.mu.Lock() + oldAddr := c.addr + c.addr = addr + c.cfg = cfg + c.mu.Unlock() + + // Clean up broadcast addresses before calling acquiredFunc + // so nothing else uses them by mistake. + // + // (The deferred RemoveAddress calls above silently error.) + c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff") + c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00") + + if c.acquiredFunc != nil { + c.acquiredFunc(oldAddr, addr, cfg) + } + if requestedAddr != "" && requestedAddr != addr { + c.stack.RemoveAddress(c.nicid, requestedAddr) } }() + h.init() h.setOp(opRequest) for i, b := 0, h.yiaddr(); i < len(b); i++ { b[i] = 0 } - h.setOptions([]option{ + for i, b := 0, h.siaddr(); i < len(b); i++ { + b[i] = 0 + } + for i, b := 0, h.giaddr(); i < len(b); i++ { + b[i] = 0 + } + reqOpts := []option{ {optDHCPMsgType, []byte{byte(dhcpREQUEST)}}, {optReqIPAddr, []byte(addr)}, - {optDHCPServer, h.siaddr()}, - }) + {optDHCPServer, []byte(cfg.ServerAddress)}, + } + if len(clientID) != 0 { + reqOpts = append(reqOpts, option{optClientID, clientID}) + } + h.setOptions(reqOpts) if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil { - return fmt.Errorf("dhcp discovery write: %v", err) + return Config{}, fmt.Errorf("dhcp discovery write: %v", err) } // DHCPACK @@ -232,53 +302,46 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error case <-ch: continue case <-ctx.Done(): - return fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted) + return Config{}, fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted) } } h = header(v) - if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) { - break + var valid bool + var e error + opts, valid, e = loadDHCPReply(h, dhcpACK, xid[:]) + if !valid { + if e != nil { + // TODO: handle all the errors? + // TODO: report malformed server responses + } + if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid { + if msg := opts.message(); msg != "" { + return Config{}, fmt.Errorf("dhcp: NAK %q", msg) + } + return Config{}, fmt.Errorf("dhcp: NAK with no message") + } + continue } + break } - opts, e := h.options() - if e != nil { - return fmt.Errorf("dhcp ack: %v", e) - } - if err := cfg.decode(opts); err != nil { - return fmt.Errorf("dhcp ack bad options: %v", err) + ack = true + return cfg, nil +} + +func loadDHCPReply(h header, typ dhcpMsgType, xid []byte) (opts options, valid bool, err error) { + if !h.isValid() || h.op() != opReply || !bytes.Equal(h.xidbytes(), xid[:]) { + return nil, false, nil } - msgtype, e := opts.dhcpMsgType() - if e != nil { - return fmt.Errorf("dhcp ack: %v", e) + opts, err = h.options() + if err != nil { + return nil, false, err } - ack = msgtype == dhcpACK - if !ack { - return fmt.Errorf("dhcp: request not acknowledged") + msgtype, err := opts.dhcpMsgType() + if err != nil { + return nil, false, err } - if cfg.LeaseLength != 0 { - go c.renewAfter(cfg.LeaseLength) + if msgtype != typ { + return nil, false, nil } - return nil -} - -func (c *Client) renewAfter(d time.Duration) { - c.mu.Lock() - defer c.mu.Unlock() - if c.cancelRenew != nil { - c.cancelRenew() - } - ctx, cancel := context.WithCancel(context.Background()) - c.cancelRenew = cancel - go func() { - timer := time.NewTimer(d) - defer timer.Stop() - select { - case <-ctx.Done(): - case <-timer.C: - if err := c.Request(ctx, c.addr); err != nil { - log.Printf("address renewal failed: %v", err) - go c.renewAfter(1 * time.Minute) - } - } - }() + return opts, true, nil } diff --git a/pkg/dhcp/dhcp.go b/pkg/dhcp/dhcp.go index 18c318fc8..ceaba34c3 100644 --- a/pkg/dhcp/dhcp.go +++ b/pkg/dhcp/dhcp.go @@ -26,19 +26,21 @@ import ( // Config is standard DHCP configuration. type Config struct { - ServerAddress tcpip.Address // address of the server - SubnetMask tcpip.AddressMask // client address subnet mask - Gateway tcpip.Address // client default gateway - DomainNameServer tcpip.Address // client domain name server - LeaseLength time.Duration // length of the address lease + Error error + ServerAddress tcpip.Address // address of the server + SubnetMask tcpip.AddressMask // client address subnet mask + Gateway tcpip.Address // client default gateway + DNS []tcpip.Address // client DNS server addresses + LeaseLength time.Duration // length of the address lease } func (cfg *Config) decode(opts []option) error { *cfg = Config{} for _, opt := range opts { b := opt.body - if l := opt.code.len(); l != -1 && l != len(b) { - return fmt.Errorf("%s bad length: %d", opt.code, len(b)) + if !opt.code.lenValid(len(b)) { + // TODO: s/%v/%s/ when `go vet` is smarter. + return fmt.Errorf("%v: bad length: %d", opt.code, len(b)) } switch opt.code { case optLeaseTime: @@ -51,7 +53,12 @@ func (cfg *Config) decode(opts []option) error { case optDefaultGateway: cfg.Gateway = tcpip.Address(b) case optDomainNameServer: - cfg.DomainNameServer = tcpip.Address(b) + for ; len(b) > 0; b = b[4:] { + if len(b) < 4 { + return fmt.Errorf("DNS bad length: %d", len(b)) + } + cfg.DNS = append(cfg.DNS, tcpip.Address(b[:4])) + } } } return nil @@ -67,8 +74,12 @@ func (cfg Config) encode() (opts []option) { if cfg.Gateway != "" { opts = append(opts, option{optDefaultGateway, []byte(cfg.Gateway)}) } - if cfg.DomainNameServer != "" { - opts = append(opts, option{optDomainNameServer, []byte(cfg.DomainNameServer)}) + if len(cfg.DNS) > 0 { + dns := make([]byte, 0, 4*len(cfg.DNS)) + for _, addr := range cfg.DNS { + dns = append(dns, addr...) + } + opts = append(opts, option{optDomainNameServer, dns}) } if l := cfg.LeaseLength / time.Second; l != 0 { v := make([]byte, 4) @@ -82,8 +93,10 @@ func (cfg Config) encode() (opts []option) { } const ( - serverPort = 67 - clientPort = 68 + // ServerPort is the well-known UDP port number for a DHCP server. + ServerPort = 67 + // ClientPort is the well-known UDP port number for a DHCP client. + ClientPort = 68 ) var magicCookie = []byte{99, 130, 83, 99} // RFC 1497 @@ -107,10 +120,10 @@ func (h header) isValid() bool { if o := h.op(); o != opRequest && o != opReply { return false } - if h[1] != 0x01 || h[2] != 0x06 || h[3] != 0x00 { + if h[1] != 0x01 || h[2] != 0x06 { return false } - return bytes.Equal(h[236:240], magicCookie) && h[len(h)-1] == 0 + return bytes.Equal(h[236:240], magicCookie) } func (h header) op() op { return op(h[0]) } @@ -141,7 +154,7 @@ func (h header) options() (opts options, err error) { } optlen := int(h[i+1]) if len(h) < i+2+optlen { - return nil, fmt.Errorf("option too long") + return nil, fmt.Errorf("option %v too long i=%d, optlen=%d", optionCode(h[i]), i, optlen) } opts = append(opts, option{ code: optionCode(h[i]), @@ -160,6 +173,8 @@ func (h header) setOptions(opts []option) { copy(h[i+2:i+2+len(opt.body)], opt.body) i += 2 + len(opt.body) } + h[i] = 255 // End option + i++ for ; i < len(h); i++ { h[i] = 0 } @@ -182,47 +197,31 @@ const ( optSubnetMask optionCode = 1 optDefaultGateway optionCode = 3 optDomainNameServer optionCode = 6 + optDomainName optionCode = 15 optReqIPAddr optionCode = 50 optLeaseTime optionCode = 51 optDHCPMsgType optionCode = 53 // dhcpMsgType optDHCPServer optionCode = 54 optParamReq optionCode = 55 + optMessage optionCode = 56 + optClientID optionCode = 61 ) -func (code optionCode) len() int { +func (code optionCode) lenValid(l int) bool { switch code { - case optSubnetMask, optDefaultGateway, optDomainNameServer, + case optSubnetMask, optDefaultGateway, optReqIPAddr, optLeaseTime, optDHCPServer: - return 4 + return l == 4 case optDHCPMsgType: - return 1 - case optParamReq: - return -1 // no fixed length - default: - return -1 - } -} - -func (code optionCode) String() string { - switch code { - case optSubnetMask: - return "option(subnet-mask)" - case optDefaultGateway: - return "option(default-gateway)" + return l == 1 case optDomainNameServer: - return "option(dns)" - case optReqIPAddr: - return "option(request-ip-address)" - case optLeaseTime: - return "option(least-time)" - case optDHCPMsgType: - return "option(message-type)" - case optDHCPServer: - return "option(server)" + return l%4 == 0 + case optMessage, optDomainName, optClientID: + return l >= 1 case optParamReq: - return "option(parameter-request)" + return true // no fixed length default: - return fmt.Sprintf("option(%d)", code) + return true // unknown option, assume ok } } @@ -232,11 +231,12 @@ func (opts options) dhcpMsgType() (dhcpMsgType, error) { for _, opt := range opts { if opt.code == optDHCPMsgType { if len(opt.body) != 1 { - return 0, fmt.Errorf("%s: bad length: %d", optDHCPMsgType, len(opt.body)) + // TODO: s/%v/%s/ when `go vet` is smarter. + return 0, fmt.Errorf("%v: bad length: %d", opt.code, len(opt.body)) } v := opt.body[0] if v <= 0 || v >= 8 { - return 0, fmt.Errorf("%s: unknown value: %d", optDHCPMsgType, v) + return 0, fmt.Errorf("DHCP bad length: %d", len(opt.body)) } return dhcpMsgType(v), nil } @@ -244,6 +244,15 @@ func (opts options) dhcpMsgType() (dhcpMsgType, error) { return 0, nil } +func (opts options) message() string { + for _, opt := range opts { + if opt.code == optMessage { + return string(opt.body) + } + } + return "" +} + func (opts options) len() int { l := 0 for _, opt := range opts { diff --git a/pkg/dhcp/dhcp_string.go b/pkg/dhcp/dhcp_string.go new file mode 100644 index 000000000..7cabed29e --- /dev/null +++ b/pkg/dhcp/dhcp_string.go @@ -0,0 +1,115 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dhcp + +import ( + "bytes" + "fmt" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" +) + +func (h header) String() string { + opts, err := h.options() + var msgtype dhcpMsgType + if err == nil { + msgtype, err = opts.dhcpMsgType() + } + if !h.isValid() || err != nil { + return fmt.Sprintf("DHCP invalid, %v %v h[1:4]=%x cookie=%x len=%d (%v)", h.op(), h.xid(), []byte(h[1:4]), []byte(h[236:240]), len(h), err) + } + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "%v %v len=%d\n", msgtype, h.xid(), len(h)) + fmt.Fprintf(buf, "\tciaddr:%v yiaddr:%v siaddr:%v giaddr:%v\n", + tcpip.Address(h.ciaddr()), + tcpip.Address(h.yiaddr()), + tcpip.Address(h.siaddr()), + tcpip.Address(h.giaddr())) + fmt.Fprintf(buf, "\tchaddr:%x", h.chaddr()) + for _, opt := range opts { + fmt.Fprintf(buf, "\n\t%v", opt) + } + return buf.String() +} + +func (opt option) String() string { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "%v: ", opt.code) + fmt.Fprintf(buf, "%x", opt.body) + return buf.String() +} + +func (code optionCode) String() string { + switch code { + case optSubnetMask: + return "option(subnet-mask)" + case optDefaultGateway: + return "option(default-gateway)" + case optDomainNameServer: + return "option(dns)" + case optDomainName: + return "option(domain-name)" + case optReqIPAddr: + return "option(request-ip-address)" + case optLeaseTime: + return "option(lease-time)" + case optDHCPMsgType: + return "option(message-type)" + case optDHCPServer: + return "option(server)" + case optParamReq: + return "option(parameter-request)" + case optMessage: + return "option(message)" + case optClientID: + return "option(client-id)" + default: + return fmt.Sprintf("option(%d)", code) + } +} + +func (o op) String() string { + switch o { + case opRequest: + return "op(request)" + case opReply: + return "op(reply)" + } + return fmt.Sprintf("op(UNKNOWN:%d)", int(o)) +} + +func (t dhcpMsgType) String() string { + switch t { + case dhcpDISCOVER: + return "DHCPDISCOVER" + case dhcpOFFER: + return "DHCPOFFER" + case dhcpREQUEST: + return "DHCPREQUEST" + case dhcpDECLINE: + return "DHCPDECLINE" + case dhcpACK: + return "DHCPACK" + case dhcpNAK: + return "DHCPNAK" + case dhcpRELEASE: + return "DHCPRELEASE" + } + return fmt.Sprintf("DHCP(%d)", int(t)) +} + +func (v xid) String() string { + return fmt.Sprintf("xid:%x", uint32(v)) +} diff --git a/pkg/dhcp/dhcp_test.go b/pkg/dhcp/dhcp_test.go index 731ed61a5..67814683a 100644 --- a/pkg/dhcp/dhcp_test.go +++ b/pkg/dhcp/dhcp_test.go @@ -27,9 +27,13 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" ) -func TestDHCP(t *testing.T) { +const nicid = tcpip.NICID(1) +const serverAddr = tcpip.Address("\xc0\xa8\x03\x01") + +func createStack(t *testing.T) *stack.Stack { const defaultMTU = 65536 id, linkEP := channel.New(256, defaultMTU, "") if testing.Verbose() { @@ -48,17 +52,9 @@ func TestDHCP(t *testing.T) { s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) - const nicid tcpip.NICID = 1 if err := s.CreateNIC(nicid, id); err != nil { t.Fatal(err) } - if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil { - t.Fatal(err) - } - if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil { - t.Fatal(err) - } - const serverAddr = tcpip.Address("\xc0\xa8\x03\x01") if err := s.AddAddress(nicid, ipv4.ProtocolNumber, serverAddr); err != nil { t.Fatal(err) } @@ -70,31 +66,38 @@ func TestDHCP(t *testing.T) { NIC: nicid, }}) - var clientAddrs = []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"} + return s +} + +func TestDHCP(t *testing.T) { + s := createStack(t) + clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"} serverCfg := Config{ - ServerAddress: serverAddr, - SubnetMask: "\xff\xff\xff\x00", - Gateway: "\xc0\xa8\x03\xF0", - DomainNameServer: "\x08\x08\x08\x08", - LeaseLength: 24 * time.Hour, + ServerAddress: serverAddr, + SubnetMask: "\xff\xff\xff\x00", + Gateway: "\xc0\xa8\x03\xF0", + DNS: []tcpip.Address{ + "\x08\x08\x08\x08", "\x08\x08\x04\x04", + }, + LeaseLength: 24 * time.Hour, } serverCtx, cancel := context.WithCancel(context.Background()) defer cancel() - _, err := NewServer(serverCtx, s, clientAddrs, serverCfg) + _, err := newEPConnServer(serverCtx, s, clientAddrs, serverCfg) if err != nil { t.Fatal(err) } const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52") - c0 := NewClient(s, nicid, clientLinkAddr0) - if err := c0.Request(context.Background(), ""); err != nil { + c0 := NewClient(s, nicid, clientLinkAddr0, nil) + if _, err := c0.Request(context.Background(), ""); err != nil { t.Fatal(err) } if got, want := c0.Address(), clientAddrs[0]; got != want { t.Errorf("c.Addr()=%s, want=%s", got, want) } - if err := c0.Request(context.Background(), ""); err != nil { + if _, err := c0.Request(context.Background(), ""); err != nil { t.Fatal(err) } if got, want := c0.Address(), clientAddrs[0]; got != want { @@ -102,22 +105,219 @@ func TestDHCP(t *testing.T) { } const clientLinkAddr1 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x53") - c1 := NewClient(s, nicid, clientLinkAddr1) - if err := c1.Request(context.Background(), ""); err != nil { + c1 := NewClient(s, nicid, clientLinkAddr1, nil) + if _, err := c1.Request(context.Background(), ""); err != nil { t.Fatal(err) } if got, want := c1.Address(), clientAddrs[1]; got != want { t.Errorf("c.Addr()=%s, want=%s", got, want) } - if err := c0.Request(context.Background(), ""); err != nil { + if _, err := c0.Request(context.Background(), ""); err != nil { t.Fatal(err) } if got, want := c0.Address(), clientAddrs[0]; got != want { t.Errorf("c.Addr()=%s, want=%s", got, want) } - if got, want := c0.Config(), serverCfg; got != want { + if got, want := c0.Config(), serverCfg; !equalConfig(got, want) { t.Errorf("client config:\n\t%#+v\nwant:\n\t%#+v", got, want) } } + +func equalConfig(c0, c1 Config) bool { + if c0.Error != c1.Error || c0.ServerAddress != c1.ServerAddress || c0.SubnetMask != c1.SubnetMask || c0.Gateway != c1.Gateway || c0.LeaseLength != c1.LeaseLength { + return false + } + if len(c0.DNS) != len(c1.DNS) { + return false + } + for i := 0; i < len(c0.DNS); i++ { + if c0.DNS[i] != c1.DNS[i] { + return false + } + } + return true +} + +func TestRenew(t *testing.T) { + s := createStack(t) + clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02"} + + serverCfg := Config{ + ServerAddress: serverAddr, + SubnetMask: "\xff\xff\xff\x00", + Gateway: "\xc0\xa8\x03\xF0", + DNS: []tcpip.Address{"\x08\x08\x08\x08"}, + LeaseLength: 1 * time.Second, + } + serverCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := newEPConnServer(serverCtx, s, clientAddrs, serverCfg) + if err != nil { + t.Fatal(err) + } + + count := 0 + var curAddr tcpip.Address + addrCh := make(chan tcpip.Address) + acquiredFunc := func(oldAddr, newAddr tcpip.Address, cfg Config) { + if err := cfg.Error; err != nil { + t.Fatalf("acquisition %d failed: %v", count, err) + } + if oldAddr != curAddr { + t.Fatalf("aquisition %d: curAddr=%v, oldAddr=%v", count, curAddr, oldAddr) + } + if cfg.LeaseLength != time.Second { + t.Fatalf("aquisition %d: lease length: %v, want %v", count, cfg.LeaseLength, time.Second) + } + count++ + curAddr = newAddr + addrCh <- newAddr + } + + clientCtx, cancel := context.WithCancel(context.Background()) + const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52") + c := NewClient(s, nicid, clientLinkAddr0, acquiredFunc) + c.Run(clientCtx) + + var addr tcpip.Address + select { + case addr = <-addrCh: + t.Logf("got first address: %v", addr) + case <-time.After(5 * time.Second): + t.Fatal("timeout acquiring initial address") + } + + select { + case newAddr := <-addrCh: + t.Logf("got renewal: %v", newAddr) + if newAddr != addr { + t.Fatalf("renewal address is %v, want %v", newAddr, addr) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for address renewal") + } + + cancel() +} + +// Regression test for https://fuchsia.atlassian.net/browse/NET-17 +func TestNoNullTerminator(t *testing.T) { + v := "\x02\x01\x06\x00" + + "\xc8\x37\xbe\x73\x00\x00\x80\x00\x00\x00\x00\x00\xc0\xa8\x2b\x92" + + "\xc0\xa8\x2b\x01\x00\x00\x00\x00\x00\x0f\x60\x0a\x23\x93\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x63\x82\x53\x63\x35\x01\x02\x36" + + "\x04\xc0\xa8\x2b\x01\x33\x04\x00\x00\x0e\x10\x3a\x04\x00\x00\x07" + + "\x08\x3b\x04\x00\x00\x0c\x4e\x01\x04\xff\xff\xff\x00\x1c\x04\xc0" + + "\xa8\x2b\xff\x03\x04\xc0\xa8\x2b\x01\x06\x04\xc0\xa8\x2b\x01\x2b" + + "\x0f\x41\x4e\x44\x52\x4f\x49\x44\x5f\x4d\x45\x54\x45\x52\x45\x44" + + "\xff" + h := header(v) + if !h.isValid() { + t.Error("failed to decode header") + } + + if got, want := h.op(), opReply; got != want { + t.Errorf("h.op()=%v, want=%v", got, want) + } + + if _, err := h.options(); err != nil { + t.Errorf("bad options: %v", err) + } +} + +func teeConn(c conn) (conn, conn) { + dup1 := &dupConn{ + c: c, + dup: make(chan connMsg, 8), + } + dup2 := &chConn{ + c: c, + ch: dup1.dup, + } + return dup1, dup2 +} + +type connMsg struct { + buf buffer.View + addr tcpip.FullAddress + err error +} + +type dupConn struct { + c conn + dup chan connMsg +} + +func (c *dupConn) Read() (buffer.View, tcpip.FullAddress, error) { + v, addr, err := c.c.Read() + c.dup <- connMsg{v, addr, err} + return v, addr, err +} +func (c *dupConn) Write(b []byte, addr *tcpip.FullAddress) error { return c.c.Write(b, addr) } + +type chConn struct { + ch chan connMsg + c conn +} + +func (c *chConn) Read() (buffer.View, tcpip.FullAddress, error) { + msg := <-c.ch + return msg.buf, msg.addr, msg.err +} +func (c *chConn) Write(b []byte, addr *tcpip.FullAddress) error { return c.c.Write(b, addr) } + +func TestTwoServers(t *testing.T) { + s := createStack(t) + + wq := new(waiter.Queue) + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("dhcp: server endpoint: %v", err) + } + if err = ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil { + t.Fatalf("dhcp: server bind: %v", err) + } + + serverCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + c1, c2 := teeConn(newEPConn(serverCtx, wq, ep)) + + if _, err := NewServer(serverCtx, c1, []tcpip.Address{"\xc0\xa8\x03\x02"}, Config{ + ServerAddress: "\xc0\xa8\x03\x01", + SubnetMask: "\xff\xff\xff\x00", + Gateway: "\xc0\xa8\x03\xF0", + DNS: []tcpip.Address{"\x08\x08\x08\x08"}, + LeaseLength: 30 * time.Minute, + }); err != nil { + t.Fatal(err) + } + if _, err := NewServer(serverCtx, c2, []tcpip.Address{"\xc0\xa8\x04\x02"}, Config{ + ServerAddress: "\xc0\xa8\x04\x01", + SubnetMask: "\xff\xff\xff\x00", + Gateway: "\xc0\xa8\x03\xF0", + DNS: []tcpip.Address{"\x08\x08\x08\x08"}, + LeaseLength: 30 * time.Minute, + }); err != nil { + t.Fatal(err) + } + + const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52") + c := NewClient(s, nicid, clientLinkAddr0, nil) + if _, err := c.Request(context.Background(), ""); err != nil { + t.Fatal(err) + } +} diff --git a/pkg/dhcp/server.go b/pkg/dhcp/server.go index 0beac7782..003e272b2 100644 --- a/pkg/dhcp/server.go +++ b/pkg/dhcp/server.go @@ -17,11 +17,13 @@ package dhcp import ( "context" "fmt" + "io" "log" "sync" "time" "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" @@ -30,10 +32,8 @@ import ( // Server is a DHCP server. type Server struct { - stack *stack.Stack + conn conn broadcast tcpip.FullAddress - wq waiter.Queue - ep tcpip.Endpoint addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure cfg Config cfgopts []option // cfg to send to client @@ -44,36 +44,96 @@ type Server struct { leases map[tcpip.LinkAddress]serverLease } +// conn is a blocking read/write network endpoint. +type conn interface { + Read() (buffer.View, tcpip.FullAddress, error) + Write([]byte, *tcpip.FullAddress) error +} + +type epConn struct { + ctx context.Context + wq *waiter.Queue + ep tcpip.Endpoint + we waiter.Entry + inCh chan struct{} +} + +func newEPConn(ctx context.Context, wq *waiter.Queue, ep tcpip.Endpoint) *epConn { + c := &epConn{ + ctx: ctx, + wq: wq, + ep: ep, + } + c.we, c.inCh = waiter.NewChannelEntry(nil) + wq.EventRegister(&c.we, waiter.EventIn) + + go func() { + <-ctx.Done() + wq.EventUnregister(&c.we) + }() + + return c +} + +func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) { + for { + var addr tcpip.FullAddress + v, _, err := c.ep.Read(&addr) + if err == tcpip.ErrWouldBlock { + select { + case <-c.inCh: + continue + case <-c.ctx.Done(): + return nil, tcpip.FullAddress{}, io.EOF + } + } + if err != nil { + return v, addr, fmt.Errorf("read: %v", err) + } + return v, addr, nil + } +} + +func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error { + if _, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr}); err != nil { + return fmt.Errorf("write: %v", err) + } + return nil +} + +func newEPConnServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) { + wq := new(waiter.Queue) + ep, err := stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + return nil, fmt.Errorf("dhcp: server endpoint: %v", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil { + return nil, fmt.Errorf("dhcp: server bind: %v", err) + } + c := newEPConn(ctx, wq, ep) + return NewServer(ctx, c, addrs, cfg) +} + // NewServer creates a new DHCP server and begins serving. // The server continues serving until ctx is done. -func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) { +func NewServer(ctx context.Context, c conn, addrs []tcpip.Address, cfg Config) (*Server, error) { + if cfg.ServerAddress == "" { + return nil, fmt.Errorf("dhcp: server requires explicit server address") + } s := &Server{ - stack: stack, + conn: c, addrs: addrs, cfg: cfg, cfgopts: cfg.encode(), broadcast: tcpip.FullAddress{ Addr: "\xff\xff\xff\xff", - Port: clientPort, + Port: ClientPort, }, handlers: make([]chan header, 8), leases: make(map[tcpip.LinkAddress]serverLease), } - var err *tcpip.Error - s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq) - if err != nil { - return nil, fmt.Errorf("dhcp: server endpoint: %v", err) - } - serverBroadcast := tcpip.FullAddress{ - Addr: "", - Port: serverPort, - } - if err := s.ep.Bind(serverBroadcast, nil); err != nil { - return nil, fmt.Errorf("dhcp: server bind: %v", err) - } - for i := 0; i < len(s.handlers); i++ { ch := make(chan header, 8) s.handlers[i] = ch @@ -108,20 +168,10 @@ func (s *Server) expirer(ctx context.Context) { // reader listens for all incoming DHCP packets and fans them out to // handling goroutines based on XID as session identifiers. func (s *Server) reader(ctx context.Context) { - we, ch := waiter.NewChannelEntry(nil) - s.wq.EventRegister(&we, waiter.EventIn) - defer s.wq.EventUnregister(&we) - for { - var addr tcpip.FullAddress - v, _, err := s.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - select { - case <-ch: - continue - case <-ctx.Done(): - return - } + v, _, err := s.conn.Read() + if err != nil { + return } h := header(v) @@ -234,21 +284,50 @@ func (s *Server) handleDiscover(hreq header, opts options) { // DHCPOFFER opts = options{{optDHCPMsgType, []byte{byte(dhcpOFFER)}}} opts = append(opts, s.cfgopts...) - h := make(header, headerBaseSize+opts.len()) + h := make(header, headerBaseSize+opts.len()+1) h.init() h.setOp(opReply) copy(h.xidbytes(), hreq.xidbytes()) copy(h.yiaddr(), lease.addr) - copy(h.siaddr(), s.cfg.ServerAddress) copy(h.chaddr(), hreq.chaddr()) h.setOptions(opts) - s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast}) + s.conn.Write([]byte(h), &s.broadcast) +} + +func (s *Server) nack(hreq header) { + // DHCPNACK + opts := options([]option{ + {optDHCPMsgType, []byte{byte(dhcpNAK)}}, + {optDHCPServer, []byte(s.cfg.ServerAddress)}, + }) + h := make(header, headerBaseSize+opts.len()+1) + h.init() + h.setOp(opReply) + copy(h.xidbytes(), hreq.xidbytes()) + copy(h.chaddr(), hreq.chaddr()) + h.setOptions(opts) + s.conn.Write([]byte(h), &s.broadcast) } func (s *Server) handleRequest(hreq header, opts options) { linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6]) xid := hreq.xid() + reqopts, err := hreq.options() + if err != nil { + s.nack(hreq) + return + } + var reqcfg Config + if err := reqcfg.decode(reqopts); err != nil { + s.nack(hreq) + return + } + if reqcfg.ServerAddress != s.cfg.ServerAddress { + // This request is for a different DHCP server. Ignore it. + return + } + s.mu.Lock() lease := s.leases[linkAddr] switch lease.state { @@ -271,15 +350,14 @@ func (s *Server) handleRequest(hreq header, opts options) { // DHCPACK opts = []option{{optDHCPMsgType, []byte{byte(dhcpACK)}}} opts = append(opts, s.cfgopts...) - h := make(header, headerBaseSize+opts.len()) + h := make(header, headerBaseSize+opts.len()+1) h.init() h.setOp(opReply) copy(h.xidbytes(), hreq.xidbytes()) copy(h.yiaddr(), lease.addr) - copy(h.siaddr(), s.cfg.ServerAddress) copy(h.chaddr(), hreq.chaddr()) h.setOptions(opts) - s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast}) + s.conn.Write([]byte(h), &s.broadcast) } type leaseState int |