summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrea Barberio <insomniac@slackware.it>2018-11-29 16:42:08 +0000
committerinsomniac <insomniacslk@users.noreply.github.com>2018-11-29 18:03:09 +0000
commitc1bfa9934849aa7934f14dd1a9e24780530855fc (patch)
treef5385d052509329232ea0fb8c0e2b81f64a489a4
parent319e92b03a0b85eeaee17398b72ee759b2ccf905 (diff)
interfaces: added package with interface facilities
-rw-r--r--README.md1
-rw-r--r--dhcpv4/server_test.go23
-rw-r--r--dhcpv6/server_test.go26
-rw-r--r--interfaces/interfaces.go41
-rw-r--r--interfaces/interfaces_test.go79
5 files changed, 130 insertions, 40 deletions
diff --git a/README.md b/README.md
index 9d852fc..425faa4 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,7 @@ The library is split into several parts:
* `iana`: several IANA constants, and helpers used by `dhcpv6` and `dhcpv4`
* `rfc1035label`: simple implementation of RFC1035 labels, used by `dhcpv6` and
`dhcpv4`
+* `interfaces`, a thin layer of wrappers around network interfaces
You will probably only need `dhcpv6` and/or `dhcpv4` explicitly. The rest is
pulled in automatically if necessary.
diff --git a/dhcpv4/server_test.go b/dhcpv4/server_test.go
index ab426df..68bd694 100644
--- a/dhcpv4/server_test.go
+++ b/dhcpv4/server_test.go
@@ -3,13 +3,13 @@
package dhcpv4
import (
- "errors"
"log"
"math/rand"
"net"
"testing"
"time"
+ "github.com/insomniacslk/dhcp/interfaces"
"github.com/stretchr/testify/require"
)
@@ -91,22 +91,6 @@ func setUpClientAndServer(handler Handler) (*Client, *Server) {
return c, s
}
-// utility function to return the loopback interface name
-// TODO this is copied from dhcpv6/server_test.go , we should refactor common code in a separate package
-func getLoopbackInterface() (string, error) {
- var ifaces []net.Interface
- var err error
- if ifaces, err = net.Interfaces(); err != nil {
- return "", err
- }
- for _, iface := range ifaces {
- if iface.Flags&net.FlagLoopback != 0 || iface.Name[:2] == "lo" {
- return iface.Name, nil
- }
- }
- return "", errors.New("No loopback interface found")
-}
-
func TestNewServer(t *testing.T) {
laddr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
@@ -125,8 +109,9 @@ func TestServerActivateAndServe(t *testing.T) {
c, s := setUpClientAndServer(DORAHandler)
defer s.Close()
- lo, err := getLoopbackInterface()
+ ifaces, err := interfaces.GetLoopbackInterfaces()
require.NoError(t, err)
+ require.NotEqual(t, 0, len(ifaces))
xid := uint32(0xaabbccdd)
hwaddr := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
@@ -136,7 +121,7 @@ func TestServerActivateAndServe(t *testing.T) {
WithHwAddr(hwaddr[:]),
}
- conv, err := c.Exchange(lo, modifiers...)
+ conv, err := c.Exchange(ifaces[0].Name, modifiers...)
require.NoError(t, err)
require.Equal(t, 4, len(conv))
for _, p := range conv {
diff --git a/dhcpv6/server_test.go b/dhcpv6/server_test.go
index 08cb507..436e808 100644
--- a/dhcpv6/server_test.go
+++ b/dhcpv6/server_test.go
@@ -5,28 +5,11 @@ import (
"net"
"testing"
"time"
- "errors"
+ "github.com/insomniacslk/dhcp/interfaces"
"github.com/stretchr/testify/require"
)
-// utility function to return the loopback interface name
-func getLoopbackInterface() (string, error) {
- var ifaces []net.Interface
- var err error
- if ifaces, err = net.Interfaces(); err != nil {
- return "", err
- }
-
- for _, iface := range ifaces {
- if iface.Flags & net.FlagLoopback != 0 || iface.Name[:2] == "lo" {
- return iface.Name, nil
- }
- }
-
- return "", errors.New("No loopback interface found")
-}
-
// utility function to set up a client and a server instance and run it in
// background. The caller needs to call Server.Close() once finished.
func setUpClientAndServer(handler Handler) (*Client, *Server) {
@@ -39,7 +22,7 @@ func setUpClientAndServer(handler Handler) (*Client, *Server) {
c := NewClient()
c.LocalAddr = &net.UDPAddr{
- IP: net.ParseIP("::1"),
+ IP: net.ParseIP("::1"),
}
for {
if s.LocalAddr() != nil {
@@ -85,9 +68,10 @@ func TestServerActivateAndServe(t *testing.T) {
c, s := setUpClientAndServer(handler)
defer s.Close()
- iface, err := getLoopbackInterface()
+ ifaces, err := interfaces.GetLoopbackInterfaces()
require.NoError(t, err)
+ require.NotEqual(t, 0, len(ifaces))
- _, _, err = c.Solicit(iface)
+ _, _, err = c.Solicit(ifaces[0].Name)
require.NoError(t, err)
}
diff --git a/interfaces/interfaces.go b/interfaces/interfaces.go
new file mode 100644
index 0000000..5761669
--- /dev/null
+++ b/interfaces/interfaces.go
@@ -0,0 +1,41 @@
+package interfaces
+
+import "net"
+
+// InterfaceMatcher is a function type used to match the interfaces we want. See
+// GetInterfacesFunc below for usage.
+type InterfaceMatcher func(net.Interface) bool
+
+// interfaceGetter is used for testing purposes
+var interfaceGetter = net.Interfaces
+
+// GetInterfacesFunc loops through the available network interfaces, and returns
+// a list of interfaces for which the passed InterfaceMatcher function returns
+// true.
+func GetInterfacesFunc(matcher InterfaceMatcher) ([]net.Interface, error) {
+ ifaces, err := interfaceGetter()
+ if err != nil {
+ return nil, err
+ }
+ ret := make([]net.Interface, 0)
+ for _, iface := range ifaces {
+ if matcher(iface) {
+ ret = append(ret, iface)
+ }
+ }
+ return ret, nil
+}
+
+// GetLoopbackInterfaces returns a list of loopback interfaces.
+func GetLoopbackInterfaces() ([]net.Interface, error) {
+ return GetInterfacesFunc(func(iface net.Interface) bool {
+ return iface.Flags&net.FlagLoopback != 0
+ })
+}
+
+// GetNonLoopbackInterfaces returns a list of non-loopback interfaces.
+func GetNonLoopbackInterfaces() ([]net.Interface, error) {
+ return GetInterfacesFunc(func(iface net.Interface) bool {
+ return iface.Flags&net.FlagLoopback == 0
+ })
+}
diff --git a/interfaces/interfaces_test.go b/interfaces/interfaces_test.go
new file mode 100644
index 0000000..7633d5e
--- /dev/null
+++ b/interfaces/interfaces_test.go
@@ -0,0 +1,79 @@
+package interfaces
+
+import (
+ "errors"
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func fakeIface(idx int, name string, loopback bool) net.Interface {
+ var flags net.Flags
+ if loopback {
+ flags |= net.FlagLoopback
+ }
+ return net.Interface{
+ Index: idx,
+ MTU: 1500,
+ Name: name,
+ HardwareAddr: []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
+ Flags: flags,
+ }
+}
+
+func TestGetLoopbackInterfaces(t *testing.T) {
+ interfaceGetter = func() ([]net.Interface, error) {
+ return []net.Interface{
+ fakeIface(0, "lo", true),
+ fakeIface(1, "eth0", false),
+ fakeIface(2, "eth1", false),
+ }, nil
+ }
+ ifaces, err := GetLoopbackInterfaces()
+ // this has to be reassigned before any require.* call
+ interfaceGetter = net.Interfaces
+
+ require.NoError(t, err)
+ require.Equal(t, 1, len(ifaces))
+}
+
+func TestGetLoopbackInterfacesError(t *testing.T) {
+ interfaceGetter = func() ([]net.Interface, error) {
+ return nil, errors.New("expected error")
+
+ }
+ _, err := GetLoopbackInterfaces()
+ // this has to be reassigned before any require.* call
+ interfaceGetter = net.Interfaces
+
+ require.Error(t, err)
+}
+
+func TestGetNonLoopbackInterfaces(t *testing.T) {
+ interfaceGetter = func() ([]net.Interface, error) {
+ return []net.Interface{
+ fakeIface(0, "lo", true),
+ fakeIface(1, "eth0", false),
+ fakeIface(2, "eth1", false),
+ }, nil
+ }
+ ifaces, err := GetNonLoopbackInterfaces()
+ // this has to be reassigned before any require.* call
+ interfaceGetter = net.Interfaces
+
+ require.NoError(t, err)
+ require.Equal(t, 2, len(ifaces))
+}
+
+func TestGetNonLoopbackInterfacesError(t *testing.T) {
+ interfaceGetter = func() ([]net.Interface, error) {
+ return nil, errors.New("expected error")
+
+ }
+ _, err := GetNonLoopbackInterfaces()
+ // this has to be reassigned before any require.* call
+ interfaceGetter = net.Interfaces
+
+ require.Error(t, err)
+}