diff options
Diffstat (limited to 'dhcpv4/server4')
-rw-r--r-- | dhcpv4/server4/logger.go | 63 | ||||
-rw-r--r-- | dhcpv4/server4/logger_test.go | 42 | ||||
-rw-r--r-- | dhcpv4/server4/server.go | 40 | ||||
-rw-r--r-- | dhcpv4/server4/server_test.go | 115 |
4 files changed, 252 insertions, 8 deletions
diff --git a/dhcpv4/server4/logger.go b/dhcpv4/server4/logger.go new file mode 100644 index 0000000..91ed861 --- /dev/null +++ b/dhcpv4/server4/logger.go @@ -0,0 +1,63 @@ +package server4 + +import ( + "github.com/insomniacslk/dhcp/dhcpv4" +) + +// Logger is a handler which will be used to output logging messages +type Logger interface { + // PrintMessage print _all_ DHCP messages + PrintMessage(prefix string, message *dhcpv4.DHCPv4) + + // Printf is use to print the rest debugging information + Printf(format string, v ...interface{}) +} + +// EmptyLogger prints nothing +type EmptyLogger struct{} + +// Printf is just a dummy function that does nothing +func (e EmptyLogger) Printf(format string, v ...interface{}) {} + +// PrintMessage is just a dummy function that does nothing +func (e EmptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {} + +// Printfer is used for actual output of the logger. For example *log.Logger is a Printfer. +type Printfer interface { + // Printf is the function for logging output. Arguments are handled in the manner of fmt.Printf. + Printf(format string, v ...interface{}) +} + +// ShortSummaryLogger is a wrapper for Printfer to implement interface Logger. +// DHCP messages are printed in the short format. +type ShortSummaryLogger struct { + // Printfer is used for actual output of the logger + Printfer +} + +// Printf prints a log message as-is via predefined Printfer +func (s ShortSummaryLogger) Printf(format string, v ...interface{}) { + s.Printfer.Printf(format, v...) +} + +// PrintMessage prints a DHCP message in the short format via predefined Printfer +func (s ShortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { + s.Printf("%s: %s", prefix, message) +} + +// DebugLogger is a wrapper for Printfer to implement interface Logger. +// DHCP messages are printed in the long format. +type DebugLogger struct { + // Printfer is used for actual output of the logger + Printfer +} + +// Printf prints a log message as-is via predefined Printfer +func (d DebugLogger) Printf(format string, v ...interface{}) { + d.Printfer.Printf(format, v...) +} + +// PrintMessage prints a DHCP message in the long format via predefined Printfer +func (d DebugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { + d.Printf("%s: %s", prefix, message.Summary()) +} diff --git a/dhcpv4/server4/logger_test.go b/dhcpv4/server4/logger_test.go new file mode 100644 index 0000000..c33d83e --- /dev/null +++ b/dhcpv4/server4/logger_test.go @@ -0,0 +1,42 @@ +// +build go1.12 + +package server4 + +import( + "log" + "os" + "testing" + + "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/stretchr/testify/require" +) + +func TestEmptyLogger(t *testing.T) { + l := EmptyLogger{} + msg, err := dhcpv4.New() + require.Nil(t, err) + l.Printf("test") + l.PrintMessage("prefix", msg) +} + +func TestShortSummaryLogger(t *testing.T) { + l := ShortSummaryLogger{ + Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), + } + msg, err := dhcpv4.New() + require.Nil(t, err) + require.NotNil(t, msg) + l.Printf("test") + l.PrintMessage("prefix", msg) +} + +func TestDebugLogger(t *testing.T) { + l := DebugLogger{ + Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), + } + msg, err := dhcpv4.New() + require.Nil(t, err) + require.NotNil(t, msg) + l.Printf("test") + l.PrintMessage("prefix", msg) +}
\ No newline at end of file diff --git a/dhcpv4/server4/server.go b/dhcpv4/server4/server.go index 9c1cee2..c50e6a5 100644 --- a/dhcpv4/server4/server.go +++ b/dhcpv4/server4/server.go @@ -3,6 +3,7 @@ package server4 import ( "log" "net" + "os" "github.com/insomniacslk/dhcp/dhcpv4" ) @@ -68,32 +69,33 @@ type Handler func(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) type Server struct { conn net.PacketConn Handler Handler + logger Logger } // Serve serves requests. func (s *Server) Serve() error { - log.Printf("Server listening on %s", s.conn.LocalAddr()) - log.Print("Ready to handle requests") + s.logger.Printf("Server listening on %s", s.conn.LocalAddr()) + s.logger.Printf("Ready to handle requests") defer s.Close() for { rbuf := make([]byte, 4096) // FIXME this is bad n, peer, err := s.conn.ReadFrom(rbuf) if err != nil { - log.Printf("Error reading from packet conn: %v", err) + s.logger.Printf("Error reading from packet conn: %v", err) return err } - log.Printf("Handling request from %v", peer) + s.logger.Printf("Handling request from %v", peer) m, err := dhcpv4.FromBytes(rbuf[:n]) if err != nil { - log.Printf("Error parsing DHCPv4 request: %v", err) + s.logger.Printf("Error parsing DHCPv4 request: %v", err) continue } upeer, ok := peer.(*net.UDPAddr) if !ok { - log.Printf("Not a UDP connection? Peer is %s", peer) + s.logger.Printf("Not a UDP connection? Peer is %s", peer) continue } // Set peer to broadcast if the client did not have an IP. @@ -126,6 +128,7 @@ func WithConn(c net.PacketConn) ServerOpt { func NewServer(ifname string, addr *net.UDPAddr, handler Handler, opt ...ServerOpt) (*Server, error) { s := &Server{ Handler: handler, + logger: EmptyLogger{}, } for _, o := range opt { @@ -141,3 +144,28 @@ func NewServer(ifname string, addr *net.UDPAddr, handler Handler, opt ...ServerO } return s, nil } + +// WithSummaryLogger logs one-line DHCPv4 message summaries when sent & received. +func WithSummaryLogger() ServerOpt { + return func(s *Server) { + s.logger = ShortSummaryLogger{ + Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), + } + } +} + +// WithDebugLogger logs multi-line full DHCPv4 messages when sent & received. +func WithDebugLogger() ServerOpt { + return func(s *Server) { + s.logger = DebugLogger{ + Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), + } + } +} + +// WithLogger set the logger (see interface Logger). +func WithLogger(newLogger Logger) ServerOpt { + return func(s *Server) { + s.logger = newLogger + } +} diff --git a/dhcpv4/server4/server_test.go b/dhcpv4/server4/server_test.go index 43314ad..9005e0e 100644 --- a/dhcpv4/server4/server_test.go +++ b/dhcpv4/server4/server_test.go @@ -7,6 +7,7 @@ import ( "log" "math/rand" "net" + "sync" "testing" "time" @@ -55,7 +56,7 @@ func DORAHandler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) { // utility function to set up a client and a server instance and run it in // background. The caller needs to call Server.Close() once finished. -func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler) (*nclient4.Client, *Server) { +func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler, logger *customLogger) (*nclient4.Client, *Server) { // strong assumption, I know loAddr := net.ParseIP("127.0.0.1") saddr := &net.UDPAddr{ @@ -70,6 +71,11 @@ func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler) (* if err != nil { t.Fatal(err) } + + if logger != nil { + s.logger = logger + } + go func() { _ = s.Serve() }() @@ -86,6 +92,30 @@ func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler) (* return c, s } +type customLogger struct { + tb testing.TB + called bool + mux sync.Mutex +} + +func (s *customLogger) Printf(format string, v ...interface{}) { + s.mux.Lock() + s.called = true + s.mux.Unlock() + s.tb.Logf("===CustomLogger BEGIN===") + s.tb.Logf(format, v...) + s.tb.Logf("===CustomLogger END===") +} + +func (s *customLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { + s.mux.Lock() + s.called = true + s.mux.Unlock() + s.tb.Logf("===CustomLogger BEGIN===") + s.tb.Logf("%s: %s", prefix, message) + s.tb.Logf("===CustomLogger END===") +} + func TestServer(t *testing.T) { ifaces, err := interfaces.GetLoopbackInterfaces() require.NoError(t, err) @@ -96,7 +126,7 @@ func TestServer(t *testing.T) { hwaddr := net.HardwareAddr{1, 2, 3, 4, 5, 6} ifaces[0].HardwareAddr = hwaddr - c, s := setUpClientAndServer(t, ifaces[0], DORAHandler) + c, s := setUpClientAndServer(t, ifaces[0], DORAHandler, nil) defer func() { require.Nil(t, s.Close()) }() @@ -127,3 +157,84 @@ func TestBadAddrFamily(t *testing.T) { t.Fatal("Expected server4.NewServer to fail with an IPv6 address") } } + +func TestServerWithCustomLogger(t *testing.T) { + ifaces, err := interfaces.GetLoopbackInterfaces() + require.NoError(t, err) + require.NotEqual(t, 0, len(ifaces)) + + // lo has a HardwareAddr of "nil". The client will drop all packets + // that don't match the HWAddr of the client interface. + hwaddr := net.HardwareAddr{1, 2, 3, 4, 5, 6} + ifaces[0].HardwareAddr = hwaddr + + c, s := setUpClientAndServer(t, ifaces[0], DORAHandler, &customLogger{ + tb: t, + }) + defer func() { + require.Nil(t, s.Close()) + }() + + xid := dhcpv4.TransactionID{0xaa, 0xbb, 0xcc, 0xdd} + + modifiers := []dhcpv4.Modifier{ + dhcpv4.WithTransactionID(xid), + dhcpv4.WithHwAddr(ifaces[0].HardwareAddr), + } + + offer, ack, err := c.Request(context.Background(), modifiers...) + require.NoError(t, err) + require.NotNil(t, offer, ack) + for _, p := range []*dhcpv4.DHCPv4{offer, ack} { + require.Equal(t, xid, p.TransactionID) + require.Equal(t, ifaces[0].HardwareAddr, p.ClientHWAddr) + } + go func() { + time.Sleep(time.Second * 5) + require.Equal(t, true, s.logger.(*customLogger).called) + }() +} + +func TestServerInstantiationWithCustomLogger(t *testing.T) { + // strong assumption, I know + loAddr := net.ParseIP("127.0.0.1") + saddr := &net.UDPAddr{ + IP: loAddr, + Port: 0, + } + s, err := NewServer("", saddr, DORAHandler, WithLogger(&customLogger{ + tb: t, + })) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, s) +} + +func TestServerInstantiationWithSummaryLogger(t *testing.T) { + // strong assumption, I know + loAddr := net.ParseIP("127.0.0.1") + saddr := &net.UDPAddr{ + IP: loAddr, + Port: 0, + } + s, err := NewServer("", saddr, DORAHandler, WithSummaryLogger()) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, s) +} + +func TestServerInstantiationWithDebugLogger(t *testing.T) { + // strong assumption, I know + loAddr := net.ParseIP("127.0.0.1") + saddr := &net.UDPAddr{ + IP: loAddr, + Port: 0, + } + s, err := NewServer("", saddr, DORAHandler, WithDebugLogger()) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, s) +} |