summaryrefslogtreecommitdiffhomepage
path: root/pkg/dhcp/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/dhcp/server.go')
-rw-r--r--pkg/dhcp/server.go154
1 files changed, 116 insertions, 38 deletions
diff --git a/pkg/dhcp/server.go b/pkg/dhcp/server.go
index 0beac7782..003e272b2 100644
--- a/pkg/dhcp/server.go
+++ b/pkg/dhcp/server.go
@@ -17,11 +17,13 @@ package dhcp
import (
"context"
"fmt"
+ "io"
"log"
"sync"
"time"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
@@ -30,10 +32,8 @@ import (
// Server is a DHCP server.
type Server struct {
- stack *stack.Stack
+ conn conn
broadcast tcpip.FullAddress
- wq waiter.Queue
- ep tcpip.Endpoint
addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure
cfg Config
cfgopts []option // cfg to send to client
@@ -44,36 +44,96 @@ type Server struct {
leases map[tcpip.LinkAddress]serverLease
}
+// conn is a blocking read/write network endpoint.
+type conn interface {
+ Read() (buffer.View, tcpip.FullAddress, error)
+ Write([]byte, *tcpip.FullAddress) error
+}
+
+type epConn struct {
+ ctx context.Context
+ wq *waiter.Queue
+ ep tcpip.Endpoint
+ we waiter.Entry
+ inCh chan struct{}
+}
+
+func newEPConn(ctx context.Context, wq *waiter.Queue, ep tcpip.Endpoint) *epConn {
+ c := &epConn{
+ ctx: ctx,
+ wq: wq,
+ ep: ep,
+ }
+ c.we, c.inCh = waiter.NewChannelEntry(nil)
+ wq.EventRegister(&c.we, waiter.EventIn)
+
+ go func() {
+ <-ctx.Done()
+ wq.EventUnregister(&c.we)
+ }()
+
+ return c
+}
+
+func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ for {
+ var addr tcpip.FullAddress
+ v, _, err := c.ep.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ select {
+ case <-c.inCh:
+ continue
+ case <-c.ctx.Done():
+ return nil, tcpip.FullAddress{}, io.EOF
+ }
+ }
+ if err != nil {
+ return v, addr, fmt.Errorf("read: %v", err)
+ }
+ return v, addr, nil
+ }
+}
+
+func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
+ if _, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr}); err != nil {
+ return fmt.Errorf("write: %v", err)
+ }
+ return nil
+}
+
+func newEPConnServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
+ wq := new(waiter.Queue)
+ ep, err := stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil {
+ return nil, fmt.Errorf("dhcp: server bind: %v", err)
+ }
+ c := newEPConn(ctx, wq, ep)
+ return NewServer(ctx, c, addrs, cfg)
+}
+
// NewServer creates a new DHCP server and begins serving.
// The server continues serving until ctx is done.
-func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
+func NewServer(ctx context.Context, c conn, addrs []tcpip.Address, cfg Config) (*Server, error) {
+ if cfg.ServerAddress == "" {
+ return nil, fmt.Errorf("dhcp: server requires explicit server address")
+ }
s := &Server{
- stack: stack,
+ conn: c,
addrs: addrs,
cfg: cfg,
cfgopts: cfg.encode(),
broadcast: tcpip.FullAddress{
Addr: "\xff\xff\xff\xff",
- Port: clientPort,
+ Port: ClientPort,
},
handlers: make([]chan header, 8),
leases: make(map[tcpip.LinkAddress]serverLease),
}
- var err *tcpip.Error
- s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq)
- if err != nil {
- return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
- }
- serverBroadcast := tcpip.FullAddress{
- Addr: "",
- Port: serverPort,
- }
- if err := s.ep.Bind(serverBroadcast, nil); err != nil {
- return nil, fmt.Errorf("dhcp: server bind: %v", err)
- }
-
for i := 0; i < len(s.handlers); i++ {
ch := make(chan header, 8)
s.handlers[i] = ch
@@ -108,20 +168,10 @@ func (s *Server) expirer(ctx context.Context) {
// reader listens for all incoming DHCP packets and fans them out to
// handling goroutines based on XID as session identifiers.
func (s *Server) reader(ctx context.Context) {
- we, ch := waiter.NewChannelEntry(nil)
- s.wq.EventRegister(&we, waiter.EventIn)
- defer s.wq.EventUnregister(&we)
-
for {
- var addr tcpip.FullAddress
- v, _, err := s.ep.Read(&addr)
- if err == tcpip.ErrWouldBlock {
- select {
- case <-ch:
- continue
- case <-ctx.Done():
- return
- }
+ v, _, err := s.conn.Read()
+ if err != nil {
+ return
}
h := header(v)
@@ -234,21 +284,50 @@ func (s *Server) handleDiscover(hreq header, opts options) {
// DHCPOFFER
opts = options{{optDHCPMsgType, []byte{byte(dhcpOFFER)}}}
opts = append(opts, s.cfgopts...)
- h := make(header, headerBaseSize+opts.len())
+ h := make(header, headerBaseSize+opts.len()+1)
h.init()
h.setOp(opReply)
copy(h.xidbytes(), hreq.xidbytes())
copy(h.yiaddr(), lease.addr)
- copy(h.siaddr(), s.cfg.ServerAddress)
copy(h.chaddr(), hreq.chaddr())
h.setOptions(opts)
- s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast})
+ s.conn.Write([]byte(h), &s.broadcast)
+}
+
+func (s *Server) nack(hreq header) {
+ // DHCPNACK
+ opts := options([]option{
+ {optDHCPMsgType, []byte{byte(dhcpNAK)}},
+ {optDHCPServer, []byte(s.cfg.ServerAddress)},
+ })
+ h := make(header, headerBaseSize+opts.len()+1)
+ h.init()
+ h.setOp(opReply)
+ copy(h.xidbytes(), hreq.xidbytes())
+ copy(h.chaddr(), hreq.chaddr())
+ h.setOptions(opts)
+ s.conn.Write([]byte(h), &s.broadcast)
}
func (s *Server) handleRequest(hreq header, opts options) {
linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6])
xid := hreq.xid()
+ reqopts, err := hreq.options()
+ if err != nil {
+ s.nack(hreq)
+ return
+ }
+ var reqcfg Config
+ if err := reqcfg.decode(reqopts); err != nil {
+ s.nack(hreq)
+ return
+ }
+ if reqcfg.ServerAddress != s.cfg.ServerAddress {
+ // This request is for a different DHCP server. Ignore it.
+ return
+ }
+
s.mu.Lock()
lease := s.leases[linkAddr]
switch lease.state {
@@ -271,15 +350,14 @@ func (s *Server) handleRequest(hreq header, opts options) {
// DHCPACK
opts = []option{{optDHCPMsgType, []byte{byte(dhcpACK)}}}
opts = append(opts, s.cfgopts...)
- h := make(header, headerBaseSize+opts.len())
+ h := make(header, headerBaseSize+opts.len()+1)
h.init()
h.setOp(opReply)
copy(h.xidbytes(), hreq.xidbytes())
copy(h.yiaddr(), lease.addr)
- copy(h.siaddr(), s.cfg.ServerAddress)
copy(h.chaddr(), hreq.chaddr())
h.setOptions(opts)
- s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast})
+ s.conn.Write([]byte(h), &s.broadcast)
}
type leaseState int