summaryrefslogtreecommitdiffhomepage
path: root/dhcpv4/server4
diff options
context:
space:
mode:
Diffstat (limited to 'dhcpv4/server4')
-rw-r--r--dhcpv4/server4/logger.go63
-rw-r--r--dhcpv4/server4/logger_test.go42
-rw-r--r--dhcpv4/server4/server.go40
-rw-r--r--dhcpv4/server4/server_test.go115
4 files changed, 252 insertions, 8 deletions
diff --git a/dhcpv4/server4/logger.go b/dhcpv4/server4/logger.go
new file mode 100644
index 0000000..91ed861
--- /dev/null
+++ b/dhcpv4/server4/logger.go
@@ -0,0 +1,63 @@
+package server4
+
+import (
+ "github.com/insomniacslk/dhcp/dhcpv4"
+)
+
+// Logger is a handler which will be used to output logging messages
+type Logger interface {
+ // PrintMessage print _all_ DHCP messages
+ PrintMessage(prefix string, message *dhcpv4.DHCPv4)
+
+ // Printf is use to print the rest debugging information
+ Printf(format string, v ...interface{})
+}
+
+// EmptyLogger prints nothing
+type EmptyLogger struct{}
+
+// Printf is just a dummy function that does nothing
+func (e EmptyLogger) Printf(format string, v ...interface{}) {}
+
+// PrintMessage is just a dummy function that does nothing
+func (e EmptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {}
+
+// Printfer is used for actual output of the logger. For example *log.Logger is a Printfer.
+type Printfer interface {
+ // Printf is the function for logging output. Arguments are handled in the manner of fmt.Printf.
+ Printf(format string, v ...interface{})
+}
+
+// ShortSummaryLogger is a wrapper for Printfer to implement interface Logger.
+// DHCP messages are printed in the short format.
+type ShortSummaryLogger struct {
+ // Printfer is used for actual output of the logger
+ Printfer
+}
+
+// Printf prints a log message as-is via predefined Printfer
+func (s ShortSummaryLogger) Printf(format string, v ...interface{}) {
+ s.Printfer.Printf(format, v...)
+}
+
+// PrintMessage prints a DHCP message in the short format via predefined Printfer
+func (s ShortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
+ s.Printf("%s: %s", prefix, message)
+}
+
+// DebugLogger is a wrapper for Printfer to implement interface Logger.
+// DHCP messages are printed in the long format.
+type DebugLogger struct {
+ // Printfer is used for actual output of the logger
+ Printfer
+}
+
+// Printf prints a log message as-is via predefined Printfer
+func (d DebugLogger) Printf(format string, v ...interface{}) {
+ d.Printfer.Printf(format, v...)
+}
+
+// PrintMessage prints a DHCP message in the long format via predefined Printfer
+func (d DebugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
+ d.Printf("%s: %s", prefix, message.Summary())
+}
diff --git a/dhcpv4/server4/logger_test.go b/dhcpv4/server4/logger_test.go
new file mode 100644
index 0000000..c33d83e
--- /dev/null
+++ b/dhcpv4/server4/logger_test.go
@@ -0,0 +1,42 @@
+// +build go1.12
+
+package server4
+
+import(
+ "log"
+ "os"
+ "testing"
+
+ "github.com/insomniacslk/dhcp/dhcpv4"
+ "github.com/stretchr/testify/require"
+)
+
+func TestEmptyLogger(t *testing.T) {
+ l := EmptyLogger{}
+ msg, err := dhcpv4.New()
+ require.Nil(t, err)
+ l.Printf("test")
+ l.PrintMessage("prefix", msg)
+}
+
+func TestShortSummaryLogger(t *testing.T) {
+ l := ShortSummaryLogger{
+ Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
+ }
+ msg, err := dhcpv4.New()
+ require.Nil(t, err)
+ require.NotNil(t, msg)
+ l.Printf("test")
+ l.PrintMessage("prefix", msg)
+}
+
+func TestDebugLogger(t *testing.T) {
+ l := DebugLogger{
+ Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
+ }
+ msg, err := dhcpv4.New()
+ require.Nil(t, err)
+ require.NotNil(t, msg)
+ l.Printf("test")
+ l.PrintMessage("prefix", msg)
+} \ No newline at end of file
diff --git a/dhcpv4/server4/server.go b/dhcpv4/server4/server.go
index 9c1cee2..c50e6a5 100644
--- a/dhcpv4/server4/server.go
+++ b/dhcpv4/server4/server.go
@@ -3,6 +3,7 @@ package server4
import (
"log"
"net"
+ "os"
"github.com/insomniacslk/dhcp/dhcpv4"
)
@@ -68,32 +69,33 @@ type Handler func(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4)
type Server struct {
conn net.PacketConn
Handler Handler
+ logger Logger
}
// Serve serves requests.
func (s *Server) Serve() error {
- log.Printf("Server listening on %s", s.conn.LocalAddr())
- log.Print("Ready to handle requests")
+ s.logger.Printf("Server listening on %s", s.conn.LocalAddr())
+ s.logger.Printf("Ready to handle requests")
defer s.Close()
for {
rbuf := make([]byte, 4096) // FIXME this is bad
n, peer, err := s.conn.ReadFrom(rbuf)
if err != nil {
- log.Printf("Error reading from packet conn: %v", err)
+ s.logger.Printf("Error reading from packet conn: %v", err)
return err
}
- log.Printf("Handling request from %v", peer)
+ s.logger.Printf("Handling request from %v", peer)
m, err := dhcpv4.FromBytes(rbuf[:n])
if err != nil {
- log.Printf("Error parsing DHCPv4 request: %v", err)
+ s.logger.Printf("Error parsing DHCPv4 request: %v", err)
continue
}
upeer, ok := peer.(*net.UDPAddr)
if !ok {
- log.Printf("Not a UDP connection? Peer is %s", peer)
+ s.logger.Printf("Not a UDP connection? Peer is %s", peer)
continue
}
// Set peer to broadcast if the client did not have an IP.
@@ -126,6 +128,7 @@ func WithConn(c net.PacketConn) ServerOpt {
func NewServer(ifname string, addr *net.UDPAddr, handler Handler, opt ...ServerOpt) (*Server, error) {
s := &Server{
Handler: handler,
+ logger: EmptyLogger{},
}
for _, o := range opt {
@@ -141,3 +144,28 @@ func NewServer(ifname string, addr *net.UDPAddr, handler Handler, opt ...ServerO
}
return s, nil
}
+
+// WithSummaryLogger logs one-line DHCPv4 message summaries when sent & received.
+func WithSummaryLogger() ServerOpt {
+ return func(s *Server) {
+ s.logger = ShortSummaryLogger{
+ Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
+ }
+ }
+}
+
+// WithDebugLogger logs multi-line full DHCPv4 messages when sent & received.
+func WithDebugLogger() ServerOpt {
+ return func(s *Server) {
+ s.logger = DebugLogger{
+ Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
+ }
+ }
+}
+
+// WithLogger set the logger (see interface Logger).
+func WithLogger(newLogger Logger) ServerOpt {
+ return func(s *Server) {
+ s.logger = newLogger
+ }
+}
diff --git a/dhcpv4/server4/server_test.go b/dhcpv4/server4/server_test.go
index 43314ad..9005e0e 100644
--- a/dhcpv4/server4/server_test.go
+++ b/dhcpv4/server4/server_test.go
@@ -7,6 +7,7 @@ import (
"log"
"math/rand"
"net"
+ "sync"
"testing"
"time"
@@ -55,7 +56,7 @@ func DORAHandler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
// utility function to set up a client and a server instance and run it in
// background. The caller needs to call Server.Close() once finished.
-func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler) (*nclient4.Client, *Server) {
+func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler, logger *customLogger) (*nclient4.Client, *Server) {
// strong assumption, I know
loAddr := net.ParseIP("127.0.0.1")
saddr := &net.UDPAddr{
@@ -70,6 +71,11 @@ func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler) (*
if err != nil {
t.Fatal(err)
}
+
+ if logger != nil {
+ s.logger = logger
+ }
+
go func() {
_ = s.Serve()
}()
@@ -86,6 +92,30 @@ func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler) (*
return c, s
}
+type customLogger struct {
+ tb testing.TB
+ called bool
+ mux sync.Mutex
+}
+
+func (s *customLogger) Printf(format string, v ...interface{}) {
+ s.mux.Lock()
+ s.called = true
+ s.mux.Unlock()
+ s.tb.Logf("===CustomLogger BEGIN===")
+ s.tb.Logf(format, v...)
+ s.tb.Logf("===CustomLogger END===")
+}
+
+func (s *customLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
+ s.mux.Lock()
+ s.called = true
+ s.mux.Unlock()
+ s.tb.Logf("===CustomLogger BEGIN===")
+ s.tb.Logf("%s: %s", prefix, message)
+ s.tb.Logf("===CustomLogger END===")
+}
+
func TestServer(t *testing.T) {
ifaces, err := interfaces.GetLoopbackInterfaces()
require.NoError(t, err)
@@ -96,7 +126,7 @@ func TestServer(t *testing.T) {
hwaddr := net.HardwareAddr{1, 2, 3, 4, 5, 6}
ifaces[0].HardwareAddr = hwaddr
- c, s := setUpClientAndServer(t, ifaces[0], DORAHandler)
+ c, s := setUpClientAndServer(t, ifaces[0], DORAHandler, nil)
defer func() {
require.Nil(t, s.Close())
}()
@@ -127,3 +157,84 @@ func TestBadAddrFamily(t *testing.T) {
t.Fatal("Expected server4.NewServer to fail with an IPv6 address")
}
}
+
+func TestServerWithCustomLogger(t *testing.T) {
+ ifaces, err := interfaces.GetLoopbackInterfaces()
+ require.NoError(t, err)
+ require.NotEqual(t, 0, len(ifaces))
+
+ // lo has a HardwareAddr of "nil". The client will drop all packets
+ // that don't match the HWAddr of the client interface.
+ hwaddr := net.HardwareAddr{1, 2, 3, 4, 5, 6}
+ ifaces[0].HardwareAddr = hwaddr
+
+ c, s := setUpClientAndServer(t, ifaces[0], DORAHandler, &customLogger{
+ tb: t,
+ })
+ defer func() {
+ require.Nil(t, s.Close())
+ }()
+
+ xid := dhcpv4.TransactionID{0xaa, 0xbb, 0xcc, 0xdd}
+
+ modifiers := []dhcpv4.Modifier{
+ dhcpv4.WithTransactionID(xid),
+ dhcpv4.WithHwAddr(ifaces[0].HardwareAddr),
+ }
+
+ offer, ack, err := c.Request(context.Background(), modifiers...)
+ require.NoError(t, err)
+ require.NotNil(t, offer, ack)
+ for _, p := range []*dhcpv4.DHCPv4{offer, ack} {
+ require.Equal(t, xid, p.TransactionID)
+ require.Equal(t, ifaces[0].HardwareAddr, p.ClientHWAddr)
+ }
+ go func() {
+ time.Sleep(time.Second * 5)
+ require.Equal(t, true, s.logger.(*customLogger).called)
+ }()
+}
+
+func TestServerInstantiationWithCustomLogger(t *testing.T) {
+ // strong assumption, I know
+ loAddr := net.ParseIP("127.0.0.1")
+ saddr := &net.UDPAddr{
+ IP: loAddr,
+ Port: 0,
+ }
+ s, err := NewServer("", saddr, DORAHandler, WithLogger(&customLogger{
+ tb: t,
+ }))
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.NotNil(t, s)
+}
+
+func TestServerInstantiationWithSummaryLogger(t *testing.T) {
+ // strong assumption, I know
+ loAddr := net.ParseIP("127.0.0.1")
+ saddr := &net.UDPAddr{
+ IP: loAddr,
+ Port: 0,
+ }
+ s, err := NewServer("", saddr, DORAHandler, WithSummaryLogger())
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.NotNil(t, s)
+}
+
+func TestServerInstantiationWithDebugLogger(t *testing.T) {
+ // strong assumption, I know
+ loAddr := net.ParseIP("127.0.0.1")
+ saddr := &net.UDPAddr{
+ IP: loAddr,
+ Port: 0,
+ }
+ s, err := NewServer("", saddr, DORAHandler, WithDebugLogger())
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.NotNil(t, s)
+}