diff options
author | Andrea Barberio <insomniac@slackware.it> | 2018-11-29 16:42:08 +0000 |
---|---|---|
committer | insomniac <insomniacslk@users.noreply.github.com> | 2018-11-29 18:03:09 +0000 |
commit | c1bfa9934849aa7934f14dd1a9e24780530855fc (patch) | |
tree | f5385d052509329232ea0fb8c0e2b81f64a489a4 /interfaces | |
parent | 319e92b03a0b85eeaee17398b72ee759b2ccf905 (diff) |
interfaces: added package with interface facilities
Diffstat (limited to 'interfaces')
-rw-r--r-- | interfaces/interfaces.go | 41 | ||||
-rw-r--r-- | interfaces/interfaces_test.go | 79 |
2 files changed, 120 insertions, 0 deletions
diff --git a/interfaces/interfaces.go b/interfaces/interfaces.go new file mode 100644 index 0000000..5761669 --- /dev/null +++ b/interfaces/interfaces.go @@ -0,0 +1,41 @@ +package interfaces + +import "net" + +// InterfaceMatcher is a function type used to match the interfaces we want. See +// GetInterfacesFunc below for usage. +type InterfaceMatcher func(net.Interface) bool + +// interfaceGetter is used for testing purposes +var interfaceGetter = net.Interfaces + +// GetInterfacesFunc loops through the available network interfaces, and returns +// a list of interfaces for which the passed InterfaceMatcher function returns +// true. +func GetInterfacesFunc(matcher InterfaceMatcher) ([]net.Interface, error) { + ifaces, err := interfaceGetter() + if err != nil { + return nil, err + } + ret := make([]net.Interface, 0) + for _, iface := range ifaces { + if matcher(iface) { + ret = append(ret, iface) + } + } + return ret, nil +} + +// GetLoopbackInterfaces returns a list of loopback interfaces. +func GetLoopbackInterfaces() ([]net.Interface, error) { + return GetInterfacesFunc(func(iface net.Interface) bool { + return iface.Flags&net.FlagLoopback != 0 + }) +} + +// GetNonLoopbackInterfaces returns a list of non-loopback interfaces. +func GetNonLoopbackInterfaces() ([]net.Interface, error) { + return GetInterfacesFunc(func(iface net.Interface) bool { + return iface.Flags&net.FlagLoopback == 0 + }) +} diff --git a/interfaces/interfaces_test.go b/interfaces/interfaces_test.go new file mode 100644 index 0000000..7633d5e --- /dev/null +++ b/interfaces/interfaces_test.go @@ -0,0 +1,79 @@ +package interfaces + +import ( + "errors" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func fakeIface(idx int, name string, loopback bool) net.Interface { + var flags net.Flags + if loopback { + flags |= net.FlagLoopback + } + return net.Interface{ + Index: idx, + MTU: 1500, + Name: name, + HardwareAddr: []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + Flags: flags, + } +} + +func TestGetLoopbackInterfaces(t *testing.T) { + interfaceGetter = func() ([]net.Interface, error) { + return []net.Interface{ + fakeIface(0, "lo", true), + fakeIface(1, "eth0", false), + fakeIface(2, "eth1", false), + }, nil + } + ifaces, err := GetLoopbackInterfaces() + // this has to be reassigned before any require.* call + interfaceGetter = net.Interfaces + + require.NoError(t, err) + require.Equal(t, 1, len(ifaces)) +} + +func TestGetLoopbackInterfacesError(t *testing.T) { + interfaceGetter = func() ([]net.Interface, error) { + return nil, errors.New("expected error") + + } + _, err := GetLoopbackInterfaces() + // this has to be reassigned before any require.* call + interfaceGetter = net.Interfaces + + require.Error(t, err) +} + +func TestGetNonLoopbackInterfaces(t *testing.T) { + interfaceGetter = func() ([]net.Interface, error) { + return []net.Interface{ + fakeIface(0, "lo", true), + fakeIface(1, "eth0", false), + fakeIface(2, "eth1", false), + }, nil + } + ifaces, err := GetNonLoopbackInterfaces() + // this has to be reassigned before any require.* call + interfaceGetter = net.Interfaces + + require.NoError(t, err) + require.Equal(t, 2, len(ifaces)) +} + +func TestGetNonLoopbackInterfacesError(t *testing.T) { + interfaceGetter = func() ([]net.Interface, error) { + return nil, errors.New("expected error") + + } + _, err := GetNonLoopbackInterfaces() + // this has to be reassigned before any require.* call + interfaceGetter = net.Interfaces + + require.Error(t, err) +} |