diff options
Diffstat (limited to 'pkg/dhcp/server.go')
-rw-r--r-- | pkg/dhcp/server.go | 289 |
1 files changed, 289 insertions, 0 deletions
diff --git a/pkg/dhcp/server.go b/pkg/dhcp/server.go new file mode 100644 index 000000000..d132d90b4 --- /dev/null +++ b/pkg/dhcp/server.go @@ -0,0 +1,289 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dhcp + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Server is a DHCP server. +type Server struct { + stack *stack.Stack + broadcast tcpip.FullAddress + wq waiter.Queue + ep tcpip.Endpoint + addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure + cfg Config + cfgopts []option // cfg to send to client + + handlers []chan header + + mu sync.Mutex + leases map[tcpip.LinkAddress]serverLease +} + +// NewServer creates a new DHCP server and begins serving. +// The server continues serving until ctx is done. +func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) { + s := &Server{ + stack: stack, + addrs: addrs, + cfg: cfg, + cfgopts: cfg.encode(), + broadcast: tcpip.FullAddress{ + Addr: "\xff\xff\xff\xff", + Port: clientPort, + }, + + handlers: make([]chan header, 8), + leases: make(map[tcpip.LinkAddress]serverLease), + } + + var err *tcpip.Error + s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq) + if err != nil { + return nil, fmt.Errorf("dhcp: server endpoint: %v", err) + } + serverBroadcast := tcpip.FullAddress{ + Addr: "", + Port: serverPort, + } + if err := s.ep.Bind(serverBroadcast, nil); err != nil { + return nil, fmt.Errorf("dhcp: server bind: %v", err) + } + + for i := 0; i < len(s.handlers); i++ { + ch := make(chan header, 8) + s.handlers[i] = ch + go s.handler(ctx, ch) + } + + go s.expirer(ctx) + go s.reader(ctx) + return s, nil +} + +func (s *Server) expirer(ctx context.Context) { + t := time.NewTicker(1 * time.Minute) + defer t.Stop() + for { + select { + case <-t.C: + s.mu.Lock() + for linkAddr, lease := range s.leases { + if time.Since(lease.start) > s.cfg.LeaseLength { + lease.state = leaseExpired + s.leases[linkAddr] = lease + } + } + s.mu.Unlock() + case <-ctx.Done(): + return + } + } +} + +// reader listens for all incoming DHCP packets and fans them out to +// handling goroutines based on XID as session identifiers. +func (s *Server) reader(ctx context.Context) { + we, ch := waiter.NewChannelEntry(nil) + s.wq.EventRegister(&we, waiter.EventIn) + defer s.wq.EventUnregister(&we) + + for { + var addr tcpip.FullAddress + v, err := s.ep.Read(&addr) + if err == tcpip.ErrWouldBlock { + select { + case <-ch: + continue + case <-ctx.Done(): + return + } + } + + h := header(v) + if !h.isValid() || h.op() != opRequest { + continue + } + xid := h.xid() + + // Fan out the packet to a handler goroutine. + // + // Use a consistent handler for a given xid, so that + // packets from a particular client are processed + // in order. + ch := s.handlers[int(xid)%len(s.handlers)] + select { + case <-ctx.Done(): + return + case ch <- h: + default: + // drop the packet + } + } +} + +func (s *Server) handler(ctx context.Context, ch chan header) { + for { + select { + case h := <-ch: + if h == nil { + return + } + opts, err := h.options() + if err != nil { + continue + } + // TODO: Handle DHCPRELEASE and DHCPDECLINE. + msgtype, err := opts.dhcpMsgType() + if err != nil { + continue + } + switch msgtype { + case dhcpDISCOVER: + s.handleDiscover(h, opts) + case dhcpREQUEST: + s.handleRequest(h, opts) + } + case <-ctx.Done(): + return + } + } +} + +func (s *Server) handleDiscover(hreq header, opts options) { + linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6]) + xid := hreq.xid() + + s.mu.Lock() + lease := s.leases[linkAddr] + switch lease.state { + case leaseNew: + if len(s.leases) < len(s.addrs) { + // Find an unused address. + // TODO: avoid building this state on each request. + alloced := make(map[tcpip.Address]bool) + for _, lease := range s.leases { + alloced[lease.addr] = true + } + for _, addr := range s.addrs { + if !alloced[addr] { + lease = serverLease{ + start: time.Now(), + addr: addr, + xid: xid, + state: leaseOffer, + } + s.leases[linkAddr] = lease + break + } + } + } else { + // No more addresses, take an expired address. + for k, oldLease := range s.leases { + if oldLease.state == leaseExpired { + delete(s.leases, k) + lease = serverLease{ + start: time.Now(), + addr: lease.addr, + xid: xid, + state: leaseOffer, + } + s.leases[linkAddr] = lease + break + } + } + log.Printf("server has no more addresses") + s.mu.Unlock() + return + } + case leaseOffer, leaseAck, leaseExpired: + lease = serverLease{ + start: time.Now(), + addr: s.leases[linkAddr].addr, + xid: xid, + state: leaseOffer, + } + s.leases[linkAddr] = lease + } + s.mu.Unlock() + + // DHCPOFFER + opts = options{{optDHCPMsgType, []byte{byte(dhcpOFFER)}}} + opts = append(opts, s.cfgopts...) + h := make(header, headerBaseSize+opts.len()) + h.init() + h.setOp(opReply) + copy(h.xidbytes(), hreq.xidbytes()) + copy(h.yiaddr(), lease.addr) + copy(h.siaddr(), s.cfg.ServerAddress) + copy(h.chaddr(), hreq.chaddr()) + h.setOptions(opts) + s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast}) +} + +func (s *Server) handleRequest(hreq header, opts options) { + linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6]) + xid := hreq.xid() + + s.mu.Lock() + lease := s.leases[linkAddr] + switch lease.state { + case leaseOffer, leaseAck, leaseExpired: + lease = serverLease{ + start: time.Now(), + addr: s.leases[linkAddr].addr, + xid: xid, + state: leaseAck, + } + s.leases[linkAddr] = lease + } + s.mu.Unlock() + + if lease.state == leaseNew { + // TODO: NACK or accept request + return + } + + // DHCPACK + opts = []option{{optDHCPMsgType, []byte{byte(dhcpACK)}}} + opts = append(opts, s.cfgopts...) + h := make(header, headerBaseSize+opts.len()) + h.init() + h.setOp(opReply) + copy(h.xidbytes(), hreq.xidbytes()) + copy(h.yiaddr(), lease.addr) + copy(h.siaddr(), s.cfg.ServerAddress) + copy(h.chaddr(), hreq.chaddr()) + h.setOptions(opts) + s.ep.Write(tcpip.SlicePayload(h), tcpip.WriteOptions{To: &s.broadcast}) +} + +type leaseState int + +const ( + leaseNew leaseState = iota + leaseOffer + leaseAck + leaseExpired +) + +type serverLease struct { + start time.Time + addr tcpip.Address + xid xid + state leaseState +} |