summaryrefslogtreecommitdiffhomepage
path: root/interfaces
diff options
context:
space:
mode:
authorAndrea Barberio <insomniac@slackware.it>2018-11-29 16:42:08 +0000
committerinsomniac <insomniacslk@users.noreply.github.com>2018-11-29 18:03:09 +0000
commitc1bfa9934849aa7934f14dd1a9e24780530855fc (patch)
treef5385d052509329232ea0fb8c0e2b81f64a489a4 /interfaces
parent319e92b03a0b85eeaee17398b72ee759b2ccf905 (diff)
interfaces: added package with interface facilities
Diffstat (limited to 'interfaces')
-rw-r--r--interfaces/interfaces.go41
-rw-r--r--interfaces/interfaces_test.go79
2 files changed, 120 insertions, 0 deletions
diff --git a/interfaces/interfaces.go b/interfaces/interfaces.go
new file mode 100644
index 0000000..5761669
--- /dev/null
+++ b/interfaces/interfaces.go
@@ -0,0 +1,41 @@
+package interfaces
+
+import "net"
+
+// InterfaceMatcher is a function type used to match the interfaces we want. See
+// GetInterfacesFunc below for usage.
+type InterfaceMatcher func(net.Interface) bool
+
+// interfaceGetter is used for testing purposes
+var interfaceGetter = net.Interfaces
+
+// GetInterfacesFunc loops through the available network interfaces, and returns
+// a list of interfaces for which the passed InterfaceMatcher function returns
+// true.
+func GetInterfacesFunc(matcher InterfaceMatcher) ([]net.Interface, error) {
+ ifaces, err := interfaceGetter()
+ if err != nil {
+ return nil, err
+ }
+ ret := make([]net.Interface, 0)
+ for _, iface := range ifaces {
+ if matcher(iface) {
+ ret = append(ret, iface)
+ }
+ }
+ return ret, nil
+}
+
+// GetLoopbackInterfaces returns a list of loopback interfaces.
+func GetLoopbackInterfaces() ([]net.Interface, error) {
+ return GetInterfacesFunc(func(iface net.Interface) bool {
+ return iface.Flags&net.FlagLoopback != 0
+ })
+}
+
+// GetNonLoopbackInterfaces returns a list of non-loopback interfaces.
+func GetNonLoopbackInterfaces() ([]net.Interface, error) {
+ return GetInterfacesFunc(func(iface net.Interface) bool {
+ return iface.Flags&net.FlagLoopback == 0
+ })
+}
diff --git a/interfaces/interfaces_test.go b/interfaces/interfaces_test.go
new file mode 100644
index 0000000..7633d5e
--- /dev/null
+++ b/interfaces/interfaces_test.go
@@ -0,0 +1,79 @@
+package interfaces
+
+import (
+ "errors"
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func fakeIface(idx int, name string, loopback bool) net.Interface {
+ var flags net.Flags
+ if loopback {
+ flags |= net.FlagLoopback
+ }
+ return net.Interface{
+ Index: idx,
+ MTU: 1500,
+ Name: name,
+ HardwareAddr: []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
+ Flags: flags,
+ }
+}
+
+func TestGetLoopbackInterfaces(t *testing.T) {
+ interfaceGetter = func() ([]net.Interface, error) {
+ return []net.Interface{
+ fakeIface(0, "lo", true),
+ fakeIface(1, "eth0", false),
+ fakeIface(2, "eth1", false),
+ }, nil
+ }
+ ifaces, err := GetLoopbackInterfaces()
+ // this has to be reassigned before any require.* call
+ interfaceGetter = net.Interfaces
+
+ require.NoError(t, err)
+ require.Equal(t, 1, len(ifaces))
+}
+
+func TestGetLoopbackInterfacesError(t *testing.T) {
+ interfaceGetter = func() ([]net.Interface, error) {
+ return nil, errors.New("expected error")
+
+ }
+ _, err := GetLoopbackInterfaces()
+ // this has to be reassigned before any require.* call
+ interfaceGetter = net.Interfaces
+
+ require.Error(t, err)
+}
+
+func TestGetNonLoopbackInterfaces(t *testing.T) {
+ interfaceGetter = func() ([]net.Interface, error) {
+ return []net.Interface{
+ fakeIface(0, "lo", true),
+ fakeIface(1, "eth0", false),
+ fakeIface(2, "eth1", false),
+ }, nil
+ }
+ ifaces, err := GetNonLoopbackInterfaces()
+ // this has to be reassigned before any require.* call
+ interfaceGetter = net.Interfaces
+
+ require.NoError(t, err)
+ require.Equal(t, 2, len(ifaces))
+}
+
+func TestGetNonLoopbackInterfacesError(t *testing.T) {
+ interfaceGetter = func() ([]net.Interface, error) {
+ return nil, errors.New("expected error")
+
+ }
+ _, err := GetNonLoopbackInterfaces()
+ // this has to be reassigned before any require.* call
+ interfaceGetter = net.Interfaces
+
+ require.Error(t, err)
+}