diff options
-rw-r--r-- | dhcpv6/server.go | 53 | ||||
-rw-r--r-- | dhcpv6/server_test.go | 54 |
2 files changed, 98 insertions, 9 deletions
diff --git a/dhcpv6/server.go b/dhcpv6/server.go index b832a8a..15edd3f 100644 --- a/dhcpv6/server.go +++ b/dhcpv6/server.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "net" + "time" ) /* @@ -44,8 +45,9 @@ func main() { } server := dhcpv6.NewServer(laddr, handler) + defer server.Close() if err := server.ActivateAndServe(); err != nil { - log.Fatal(err) + log.Panic(err) } } @@ -57,15 +59,25 @@ type Handler func(conn net.PacketConn, peer net.Addr, m DHCPv6) // Server represents a DHCPv6 server object type Server struct { - conn net.PacketConn - LocalAddr net.UDPAddr - Handler Handler + conn net.PacketConn + shouldStop bool + running bool + Handler Handler + localAddr net.UDPAddr +} + +func (s *Server) LocalAddr() net.Addr { + if s.conn == nil { + return nil + } + return s.conn.LocalAddr() } // ActivateAndServe starts the DHCPv6 server func (s *Server) ActivateAndServe() error { + s.shouldStop = false if s.conn == nil { - conn, err := net.ListenUDP("udp6", &s.LocalAddr) + conn, err := net.ListenUDP("udp6", &s.localAddr) if err != nil { return err } @@ -83,12 +95,23 @@ func (s *Server) ActivateAndServe() error { } log.Printf("Server listening on %s", pc.LocalAddr()) log.Print("Ready to handle requests") + s.running = true for { - log.Printf("Waiting..") + if s.shouldStop { + s.running = false + break + } + pc.SetReadDeadline(time.Now().Add(time.Second)) rbuf := make([]byte, 4096) // FIXME this is bad n, peer, err := pc.ReadFrom(rbuf) if err != nil { - log.Printf("Error reading from packet conn: %v", err) + switch err.(type) { + case net.Error: + // silently skip and continue + default: + //complain and continue + log.Printf("Error reading from packet conn: %v", err) + } continue } log.Printf("Handling request from %v", peer) @@ -103,10 +126,24 @@ func (s *Server) ActivateAndServe() error { return nil } +func (s *Server) Close() error { + s.shouldStop = true + for { + if !s.running { + break + } + time.Sleep(100 * time.Millisecond) + } + if s.conn != nil { + return s.conn.Close() + } + return nil +} + // NewServer initializes and returns a new Server object func NewServer(addr net.UDPAddr, handler Handler) *Server { return &Server{ - LocalAddr: addr, + localAddr: addr, Handler: handler, } } diff --git a/dhcpv6/server_test.go b/dhcpv6/server_test.go index 3794a4f..4fc919b 100644 --- a/dhcpv6/server_test.go +++ b/dhcpv6/server_test.go @@ -1,21 +1,73 @@ package dhcpv6 import ( + "log" "net" "testing" + "time" "github.com/stretchr/testify/require" ) +// utility function to set up a client and a server instance and run it in +// background. The caller needs to call Server.Close() once finished. +func setUpClientAndServer(handler Handler) (*Client, *Server) { + laddr := net.UDPAddr{ + IP: net.ParseIP("::1"), + Port: 0, + Zone: "lo", + } + s := NewServer(laddr, handler) + go s.ActivateAndServe() + + c := NewClient() + c.LocalAddr = &net.UDPAddr{ + IP: net.ParseIP("::1"), + Zone: "lo", + } + for { + if s.LocalAddr() != nil { + break + } + time.Sleep(10 * time.Millisecond) + log.Printf("Waiting for server to run...") + } + c.RemoteAddr = &net.UDPAddr{ + IP: net.ParseIP("::1"), + Port: s.LocalAddr().(*net.UDPAddr).Port, + Zone: "lo", + } + + return c, s +} + func TestNewServer(t *testing.T) { laddr := net.UDPAddr{ IP: net.ParseIP("::1"), Port: 0, + Zone: "lo", } handler := func(conn net.PacketConn, peer net.Addr, m DHCPv6) {} s := NewServer(laddr, handler) + defer s.Close() + require.NotNil(t, s) require.Nil(t, s.conn) - require.Equal(t, laddr, s.LocalAddr) + require.Equal(t, laddr, s.localAddr) require.NotNil(t, s.Handler) } + +func TestServerActivateAndServe(t *testing.T) { + handler := func(conn net.PacketConn, peer net.Addr, m DHCPv6) { + log.Printf("MESSAGE from %s, reply with %v", peer, m.ToBytes()) + if _, err := conn.WriteTo(m.ToBytes(), peer); err != nil { + log.Printf("Cannot reply to client: %v", err) + } + } + c, s := setUpClientAndServer(handler) + defer s.Close() + + _, _, err := c.Solicit("lo", nil) + + require.NoError(t, err) +} |