diff options
author | Andrea Barberio <insomniac@slackware.it> | 2018-11-29 16:42:08 +0000 |
---|---|---|
committer | insomniac <insomniacslk@users.noreply.github.com> | 2018-11-29 18:03:09 +0000 |
commit | c1bfa9934849aa7934f14dd1a9e24780530855fc (patch) | |
tree | f5385d052509329232ea0fb8c0e2b81f64a489a4 | |
parent | 319e92b03a0b85eeaee17398b72ee759b2ccf905 (diff) |
interfaces: added package with interface facilities
-rw-r--r-- | README.md | 1 | ||||
-rw-r--r-- | dhcpv4/server_test.go | 23 | ||||
-rw-r--r-- | dhcpv6/server_test.go | 26 | ||||
-rw-r--r-- | interfaces/interfaces.go | 41 | ||||
-rw-r--r-- | interfaces/interfaces_test.go | 79 |
5 files changed, 130 insertions, 40 deletions
@@ -14,6 +14,7 @@ The library is split into several parts: * `iana`: several IANA constants, and helpers used by `dhcpv6` and `dhcpv4` * `rfc1035label`: simple implementation of RFC1035 labels, used by `dhcpv6` and `dhcpv4` +* `interfaces`, a thin layer of wrappers around network interfaces You will probably only need `dhcpv6` and/or `dhcpv4` explicitly. The rest is pulled in automatically if necessary. diff --git a/dhcpv4/server_test.go b/dhcpv4/server_test.go index ab426df..68bd694 100644 --- a/dhcpv4/server_test.go +++ b/dhcpv4/server_test.go @@ -3,13 +3,13 @@ package dhcpv4 import ( - "errors" "log" "math/rand" "net" "testing" "time" + "github.com/insomniacslk/dhcp/interfaces" "github.com/stretchr/testify/require" ) @@ -91,22 +91,6 @@ func setUpClientAndServer(handler Handler) (*Client, *Server) { return c, s } -// utility function to return the loopback interface name -// TODO this is copied from dhcpv6/server_test.go , we should refactor common code in a separate package -func getLoopbackInterface() (string, error) { - var ifaces []net.Interface - var err error - if ifaces, err = net.Interfaces(); err != nil { - return "", err - } - for _, iface := range ifaces { - if iface.Flags&net.FlagLoopback != 0 || iface.Name[:2] == "lo" { - return iface.Name, nil - } - } - return "", errors.New("No loopback interface found") -} - func TestNewServer(t *testing.T) { laddr := net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -125,8 +109,9 @@ func TestServerActivateAndServe(t *testing.T) { c, s := setUpClientAndServer(DORAHandler) defer s.Close() - lo, err := getLoopbackInterface() + ifaces, err := interfaces.GetLoopbackInterfaces() require.NoError(t, err) + require.NotEqual(t, 0, len(ifaces)) xid := uint32(0xaabbccdd) hwaddr := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} @@ -136,7 +121,7 @@ func TestServerActivateAndServe(t *testing.T) { WithHwAddr(hwaddr[:]), } - conv, err := c.Exchange(lo, modifiers...) + conv, err := c.Exchange(ifaces[0].Name, modifiers...) require.NoError(t, err) require.Equal(t, 4, len(conv)) for _, p := range conv { diff --git a/dhcpv6/server_test.go b/dhcpv6/server_test.go index 08cb507..436e808 100644 --- a/dhcpv6/server_test.go +++ b/dhcpv6/server_test.go @@ -5,28 +5,11 @@ import ( "net" "testing" "time" - "errors" + "github.com/insomniacslk/dhcp/interfaces" "github.com/stretchr/testify/require" ) -// utility function to return the loopback interface name -func getLoopbackInterface() (string, error) { - var ifaces []net.Interface - var err error - if ifaces, err = net.Interfaces(); err != nil { - return "", err - } - - for _, iface := range ifaces { - if iface.Flags & net.FlagLoopback != 0 || iface.Name[:2] == "lo" { - return iface.Name, nil - } - } - - return "", errors.New("No loopback interface found") -} - // utility function to set up a client and a server instance and run it in // background. The caller needs to call Server.Close() once finished. func setUpClientAndServer(handler Handler) (*Client, *Server) { @@ -39,7 +22,7 @@ func setUpClientAndServer(handler Handler) (*Client, *Server) { c := NewClient() c.LocalAddr = &net.UDPAddr{ - IP: net.ParseIP("::1"), + IP: net.ParseIP("::1"), } for { if s.LocalAddr() != nil { @@ -85,9 +68,10 @@ func TestServerActivateAndServe(t *testing.T) { c, s := setUpClientAndServer(handler) defer s.Close() - iface, err := getLoopbackInterface() + ifaces, err := interfaces.GetLoopbackInterfaces() require.NoError(t, err) + require.NotEqual(t, 0, len(ifaces)) - _, _, err = c.Solicit(iface) + _, _, err = c.Solicit(ifaces[0].Name) require.NoError(t, err) } diff --git a/interfaces/interfaces.go b/interfaces/interfaces.go new file mode 100644 index 0000000..5761669 --- /dev/null +++ b/interfaces/interfaces.go @@ -0,0 +1,41 @@ +package interfaces + +import "net" + +// InterfaceMatcher is a function type used to match the interfaces we want. See +// GetInterfacesFunc below for usage. +type InterfaceMatcher func(net.Interface) bool + +// interfaceGetter is used for testing purposes +var interfaceGetter = net.Interfaces + +// GetInterfacesFunc loops through the available network interfaces, and returns +// a list of interfaces for which the passed InterfaceMatcher function returns +// true. +func GetInterfacesFunc(matcher InterfaceMatcher) ([]net.Interface, error) { + ifaces, err := interfaceGetter() + if err != nil { + return nil, err + } + ret := make([]net.Interface, 0) + for _, iface := range ifaces { + if matcher(iface) { + ret = append(ret, iface) + } + } + return ret, nil +} + +// GetLoopbackInterfaces returns a list of loopback interfaces. +func GetLoopbackInterfaces() ([]net.Interface, error) { + return GetInterfacesFunc(func(iface net.Interface) bool { + return iface.Flags&net.FlagLoopback != 0 + }) +} + +// GetNonLoopbackInterfaces returns a list of non-loopback interfaces. +func GetNonLoopbackInterfaces() ([]net.Interface, error) { + return GetInterfacesFunc(func(iface net.Interface) bool { + return iface.Flags&net.FlagLoopback == 0 + }) +} diff --git a/interfaces/interfaces_test.go b/interfaces/interfaces_test.go new file mode 100644 index 0000000..7633d5e --- /dev/null +++ b/interfaces/interfaces_test.go @@ -0,0 +1,79 @@ +package interfaces + +import ( + "errors" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func fakeIface(idx int, name string, loopback bool) net.Interface { + var flags net.Flags + if loopback { + flags |= net.FlagLoopback + } + return net.Interface{ + Index: idx, + MTU: 1500, + Name: name, + HardwareAddr: []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + Flags: flags, + } +} + +func TestGetLoopbackInterfaces(t *testing.T) { + interfaceGetter = func() ([]net.Interface, error) { + return []net.Interface{ + fakeIface(0, "lo", true), + fakeIface(1, "eth0", false), + fakeIface(2, "eth1", false), + }, nil + } + ifaces, err := GetLoopbackInterfaces() + // this has to be reassigned before any require.* call + interfaceGetter = net.Interfaces + + require.NoError(t, err) + require.Equal(t, 1, len(ifaces)) +} + +func TestGetLoopbackInterfacesError(t *testing.T) { + interfaceGetter = func() ([]net.Interface, error) { + return nil, errors.New("expected error") + + } + _, err := GetLoopbackInterfaces() + // this has to be reassigned before any require.* call + interfaceGetter = net.Interfaces + + require.Error(t, err) +} + +func TestGetNonLoopbackInterfaces(t *testing.T) { + interfaceGetter = func() ([]net.Interface, error) { + return []net.Interface{ + fakeIface(0, "lo", true), + fakeIface(1, "eth0", false), + fakeIface(2, "eth1", false), + }, nil + } + ifaces, err := GetNonLoopbackInterfaces() + // this has to be reassigned before any require.* call + interfaceGetter = net.Interfaces + + require.NoError(t, err) + require.Equal(t, 2, len(ifaces)) +} + +func TestGetNonLoopbackInterfacesError(t *testing.T) { + interfaceGetter = func() ([]net.Interface, error) { + return nil, errors.New("expected error") + + } + _, err := GetNonLoopbackInterfaces() + // this has to be reassigned before any require.* call + interfaceGetter = net.Interfaces + + require.Error(t, err) +} |