diff options
-rw-r--r-- | dhcpv4/server4/conn.go | 15 | ||||
-rw-r--r-- | dhcpv4/server4/server.go | 2 | ||||
-rw-r--r-- | dhcpv4/server4/server_test.go | 31 |
3 files changed, 31 insertions, 17 deletions
diff --git a/dhcpv4/server4/conn.go b/dhcpv4/server4/conn.go index 0a4c73a..3e49669 100644 --- a/dhcpv4/server4/conn.go +++ b/dhcpv4/server4/conn.go @@ -14,7 +14,7 @@ import ( // given based on a IPv4 DGRAM socket. The UDP connection allows broadcasting. // // The interface must already be configured. -func NewIPv4UDPConn(iface string, port int) (*net.UDPConn, error) { +func NewIPv4UDPConn(iface string, addr *net.UDPAddr) (*net.UDPConn, error) { fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err != nil { return nil, fmt.Errorf("cannot get a UDP socket: %v", err) @@ -37,9 +37,18 @@ func NewIPv4UDPConn(iface string, port int) (*net.UDPConn, error) { return nil, fmt.Errorf("cannot bind to interface %s: %v", iface, err) } } + + if addr == nil { + addr = &net.UDPAddr{Port: dhcpv4.ServerPort} + } // Bind to the port. - if err := unix.Bind(fd, &unix.SockaddrInet4{Port: port}); err != nil { - return nil, fmt.Errorf("cannot bind to port %d: %v", port, err) + saddr := unix.SockaddrInet4{Port: addr.Port} + if addr.IP != nil && addr.IP.To4() == nil { + return nil, fmt.Errorf("wrong address family (expected v4) for %s", addr.IP) + } + copy(saddr.Addr[:], addr.IP.To4()) + if err := unix.Bind(fd, &saddr); err != nil { + return nil, fmt.Errorf("cannot bind to port %d: %v", addr.Port, err) } conn, err := net.FilePacketConn(f) diff --git a/dhcpv4/server4/server.go b/dhcpv4/server4/server.go index fe4ef09..8bf7924 100644 --- a/dhcpv4/server4/server.go +++ b/dhcpv4/server4/server.go @@ -133,7 +133,7 @@ func NewServer(ifname string, addr *net.UDPAddr, handler Handler, opt ...ServerO } if s.conn == nil { var err error - conn, err := NewIPv4UDPConn(ifname, addr.Port) + conn, err := NewIPv4UDPConn(ifname, addr) if err != nil { return nil, err } diff --git a/dhcpv4/server4/server_test.go b/dhcpv4/server4/server_test.go index a596d04..43314ad 100644 --- a/dhcpv4/server4/server_test.go +++ b/dhcpv4/server4/server_test.go @@ -22,13 +22,6 @@ func init() { rand.Seed(time.Now().UTC().UnixNano()) } -func randPort() int { - // can't use port 0 with raw sockets, so until we implement - // a non-raw-sockets client for non-static ports, we have to - // deal with this "randomness" - return 32*1024 + rand.Intn(65536-32*1024) -} - // DORAHandler is a server handler suitable for DORA transactions func DORAHandler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) { if m == nil { @@ -65,27 +58,28 @@ func DORAHandler(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) { func setUpClientAndServer(t *testing.T, iface net.Interface, handler Handler) (*nclient4.Client, *Server) { // strong assumption, I know loAddr := net.ParseIP("127.0.0.1") - saddr := net.UDPAddr{ + saddr := &net.UDPAddr{ IP: loAddr, - Port: randPort(), + Port: 0, } caddr := net.UDPAddr{ IP: loAddr, - Port: randPort(), + Port: 0, } - s, err := NewServer("", &saddr, handler) + s, err := NewServer("", saddr, handler) if err != nil { t.Fatal(err) } go func() { _ = s.Serve() }() + saddr = s.conn.LocalAddr().(*net.UDPAddr) - clientConn, err := NewIPv4UDPConn("", caddr.Port) + clientConn, err := NewIPv4UDPConn("", &caddr) if err != nil { t.Fatal(err) } - c, err := nclient4.NewWithConn(clientConn, iface.HardwareAddr, nclient4.WithServerAddr(&saddr)) + c, err := nclient4.NewWithConn(clientConn, iface.HardwareAddr, nclient4.WithServerAddr(saddr)) if err != nil { t.Fatal(err) } @@ -122,3 +116,14 @@ func TestServer(t *testing.T) { require.Equal(t, ifaces[0].HardwareAddr, p.ClientHWAddr) } } + +func TestBadAddrFamily(t *testing.T) { + saddr := &net.UDPAddr{ + IP: net.IPv6loopback, + Port: 0, + } + _, err := NewServer("", saddr, DORAHandler) + if err == nil { + t.Fatal("Expected server4.NewServer to fail with an IPv6 address") + } +} |