diff options
-rw-r--r-- | dhcpv6/async/client.go | 2 | ||||
-rw-r--r-- | dhcpv6/client.go | 2 | ||||
-rw-r--r-- | dhcpv6/iputils.go | 48 | ||||
-rw-r--r-- | dhcpv6/iputils_test.go | 103 |
4 files changed, 136 insertions, 19 deletions
diff --git a/dhcpv6/async/client.go b/dhcpv6/async/client.go index 08c2cfb..c574208 100644 --- a/dhcpv6/async/client.go +++ b/dhcpv6/async/client.go @@ -45,7 +45,7 @@ func (c *Client) OpenForInterface(ifname string, bufferSize int) error { if err != nil { return err } - c.LocalAddr = &net.UDPAddr{IP: *addr, Port: dhcpv6.DefaultClientPort, Zone: ifname} + c.LocalAddr = &net.UDPAddr{IP: addr, Port: dhcpv6.DefaultClientPort, Zone: ifname} return c.Open(bufferSize) } diff --git a/dhcpv6/client.go b/dhcpv6/client.go index 3ed7861..10a20c9 100644 --- a/dhcpv6/client.go +++ b/dhcpv6/client.go @@ -100,7 +100,7 @@ func (c *Client) sendReceive(ifname string, packet DHCPv6, expectedType MessageT if err != nil { return nil, err } - laddr = net.UDPAddr{IP: *llAddr, Port: DefaultClientPort, Zone: ifname} + laddr = net.UDPAddr{IP: llAddr, Port: DefaultClientPort, Zone: ifname} } else { if addr, ok := c.LocalAddr.(*net.UDPAddr); ok { laddr = *addr diff --git a/dhcpv6/iputils.go b/dhcpv6/iputils.go index c3ac3aa..2b9788c 100644 --- a/dhcpv6/iputils.go +++ b/dhcpv6/iputils.go @@ -5,26 +5,40 @@ import ( "net" ) -func GetLinkLocalAddr(ifname string) (*net.IP, error) { - ifaces, err := net.Interfaces() +// InterfaceAddresses is used to fetch addresses of an interface with given name +var InterfaceAddresses func(string) ([]net.Addr, error) = interfaceAddresses + +func interfaceAddresses(ifname string) ([]net.Addr, error) { + iface, err := net.InterfaceByName(ifname) if err != nil { return nil, err } - for _, iface := range ifaces { - if iface.Name != ifname { - continue - } - ifaddrs, err := iface.Addrs() - if err != nil { - return nil, err - } - for _, ifaddr := range ifaddrs { - if ifaddr, ok := ifaddr.(*net.IPNet); ok { - if ifaddr.IP.To4() == nil && ifaddr.IP.IsLinkLocalUnicast() { - return &ifaddr.IP, nil - } - } + return iface.Addrs() +} + +func getMatchingAddr(ifname string, matches func(net.IP) bool) (net.IP, error) { + ifaddrs, err := InterfaceAddresses(ifname) + if err != nil { + return nil, err + } + for _, ifaddr := range ifaddrs { + if ifaddr, ok := ifaddr.(*net.IPNet); ok && matches(ifaddr.IP) { + return ifaddr.IP, nil } } - return nil, fmt.Errorf("No link-local address found for interface %v", ifname) + return nil, fmt.Errorf("no matching address found for interface %s", ifname) +} + +// GetLinkLocalAddr returns a link-local address for the interface +func GetLinkLocalAddr(ifname string) (net.IP, error) { + return getMatchingAddr(ifname, func(ip net.IP) bool { + return ip.To4() == nil && ip.IsLinkLocalUnicast() + }) +} + +// GetGlobalAddr returns a global address for the interface +func GetGlobalAddr(ifname string) (net.IP, error) { + return getMatchingAddr(ifname, func(ip net.IP) bool { + return ip.To4() == nil && ip.IsGlobalUnicast() + }) } diff --git a/dhcpv6/iputils_test.go b/dhcpv6/iputils_test.go new file mode 100644 index 0000000..765792e --- /dev/null +++ b/dhcpv6/iputils_test.go @@ -0,0 +1,103 @@ +package dhcpv6 + +import ( + "errors" + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" +) + +var ErrDummy = errors.New("dummy error") + +type MatchingAddressTestSuite struct { + suite.Suite + m mock.Mock + + ips []net.IP + addrs []net.Addr +} + +func (s *MatchingAddressTestSuite) InterfaceAddresses(name string) ([]net.Addr, error) { + args := s.m.Called(name) + if args.Get(0) == nil { + return nil, args.Error(1) + } + if ifaddrs, ok := args.Get(0).([]net.Addr); ok { + return ifaddrs, args.Error(1) + } + panic(fmt.Sprintf("assert: arguments: InterfaceAddresses(0) failed because object wasn't correct type: %v", args.Get(0))) +} + +func (s *MatchingAddressTestSuite) Match(ip net.IP) bool { + args := s.m.Called(ip) + return args.Bool(0) +} + +func (s *MatchingAddressTestSuite) SetupTest() { + InterfaceAddresses = s.InterfaceAddresses + s.ips = []net.IP{ + net.ParseIP("2401:db00:3020:70e1:face:0:7e:0"), + net.ParseIP("2803:6080:890c:847e::1"), + net.ParseIP("fe80::4a57:ddff:fe04:d8e9"), + } + s.addrs = []net.Addr{} + for _, ip := range s.ips { + s.addrs = append(s.addrs, &net.IPNet{IP: ip}) + } +} + +func (s *MatchingAddressTestSuite) TestGetMatchingAddr() { + // Check if error from InterfaceAddresses immidately returns error + s.m.On("InterfaceAddresses", "eth0").Return(nil, ErrDummy).Once() + _, err := getMatchingAddr("eth0", s.Match) + s.Assert().Equal(ErrDummy, err) + s.m.AssertExpectations(s.T()) + // Check if the looping is stopped after finding a matching address + s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once() + s.m.On("Match", s.ips[0]).Return(false).Once() + s.m.On("Match", s.ips[1]).Return(true).Once() + ip, err := getMatchingAddr("eth0", s.Match) + s.Require().NoError(err) + s.Assert().Equal(s.ips[1], ip) + s.m.AssertExpectations(s.T()) + // Check if the looping skips not matching addresses + s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once() + s.m.On("Match", s.ips[0]).Return(false).Once() + s.m.On("Match", s.ips[1]).Return(false).Once() + s.m.On("Match", s.ips[2]).Return(true).Once() + ip, err = getMatchingAddr("eth0", s.Match) + s.Require().NoError(err) + s.Assert().Equal(s.ips[2], ip) + s.m.AssertExpectations(s.T()) + // Check if the error is returned if no matching address is found + s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once() + s.m.On("Match", s.ips[0]).Return(false).Once() + s.m.On("Match", s.ips[1]).Return(false).Once() + s.m.On("Match", s.ips[2]).Return(false).Once() + _, err = getMatchingAddr("eth0", s.Match) + s.Assert().EqualError(err, "no matching address found for interface eth0") + s.m.AssertExpectations(s.T()) +} + +func (s *MatchingAddressTestSuite) TestGetLinkLocalAddr() { + s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once() + ip, err := GetLinkLocalAddr("eth0") + s.Require().NoError(err) + s.Assert().Equal(s.ips[2], ip) + s.m.AssertExpectations(s.T()) +} + +func (s *MatchingAddressTestSuite) TestGetGlobalAddr() { + s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once() + ip, err := GetGlobalAddr("eth0") + s.Require().NoError(err) + s.Assert().Equal(s.ips[0], ip) + s.m.AssertExpectations(s.T()) +} + +func TestMatchingAddressTestSuite(t *testing.T) { + suite.Run(t, new(MatchingAddressTestSuite)) +} |