summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv6/async/client.go2
-rw-r--r--dhcpv6/client.go2
-rw-r--r--dhcpv6/iputils.go48
-rw-r--r--dhcpv6/iputils_test.go103
4 files changed, 136 insertions, 19 deletions
diff --git a/dhcpv6/async/client.go b/dhcpv6/async/client.go
index 08c2cfb..c574208 100644
--- a/dhcpv6/async/client.go
+++ b/dhcpv6/async/client.go
@@ -45,7 +45,7 @@ func (c *Client) OpenForInterface(ifname string, bufferSize int) error {
if err != nil {
return err
}
- c.LocalAddr = &net.UDPAddr{IP: *addr, Port: dhcpv6.DefaultClientPort, Zone: ifname}
+ c.LocalAddr = &net.UDPAddr{IP: addr, Port: dhcpv6.DefaultClientPort, Zone: ifname}
return c.Open(bufferSize)
}
diff --git a/dhcpv6/client.go b/dhcpv6/client.go
index 3ed7861..10a20c9 100644
--- a/dhcpv6/client.go
+++ b/dhcpv6/client.go
@@ -100,7 +100,7 @@ func (c *Client) sendReceive(ifname string, packet DHCPv6, expectedType MessageT
if err != nil {
return nil, err
}
- laddr = net.UDPAddr{IP: *llAddr, Port: DefaultClientPort, Zone: ifname}
+ laddr = net.UDPAddr{IP: llAddr, Port: DefaultClientPort, Zone: ifname}
} else {
if addr, ok := c.LocalAddr.(*net.UDPAddr); ok {
laddr = *addr
diff --git a/dhcpv6/iputils.go b/dhcpv6/iputils.go
index c3ac3aa..2b9788c 100644
--- a/dhcpv6/iputils.go
+++ b/dhcpv6/iputils.go
@@ -5,26 +5,40 @@ import (
"net"
)
-func GetLinkLocalAddr(ifname string) (*net.IP, error) {
- ifaces, err := net.Interfaces()
+// InterfaceAddresses is used to fetch addresses of an interface with given name
+var InterfaceAddresses func(string) ([]net.Addr, error) = interfaceAddresses
+
+func interfaceAddresses(ifname string) ([]net.Addr, error) {
+ iface, err := net.InterfaceByName(ifname)
if err != nil {
return nil, err
}
- for _, iface := range ifaces {
- if iface.Name != ifname {
- continue
- }
- ifaddrs, err := iface.Addrs()
- if err != nil {
- return nil, err
- }
- for _, ifaddr := range ifaddrs {
- if ifaddr, ok := ifaddr.(*net.IPNet); ok {
- if ifaddr.IP.To4() == nil && ifaddr.IP.IsLinkLocalUnicast() {
- return &ifaddr.IP, nil
- }
- }
+ return iface.Addrs()
+}
+
+func getMatchingAddr(ifname string, matches func(net.IP) bool) (net.IP, error) {
+ ifaddrs, err := InterfaceAddresses(ifname)
+ if err != nil {
+ return nil, err
+ }
+ for _, ifaddr := range ifaddrs {
+ if ifaddr, ok := ifaddr.(*net.IPNet); ok && matches(ifaddr.IP) {
+ return ifaddr.IP, nil
}
}
- return nil, fmt.Errorf("No link-local address found for interface %v", ifname)
+ return nil, fmt.Errorf("no matching address found for interface %s", ifname)
+}
+
+// GetLinkLocalAddr returns a link-local address for the interface
+func GetLinkLocalAddr(ifname string) (net.IP, error) {
+ return getMatchingAddr(ifname, func(ip net.IP) bool {
+ return ip.To4() == nil && ip.IsLinkLocalUnicast()
+ })
+}
+
+// GetGlobalAddr returns a global address for the interface
+func GetGlobalAddr(ifname string) (net.IP, error) {
+ return getMatchingAddr(ifname, func(ip net.IP) bool {
+ return ip.To4() == nil && ip.IsGlobalUnicast()
+ })
}
diff --git a/dhcpv6/iputils_test.go b/dhcpv6/iputils_test.go
new file mode 100644
index 0000000..765792e
--- /dev/null
+++ b/dhcpv6/iputils_test.go
@@ -0,0 +1,103 @@
+package dhcpv6
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/suite"
+)
+
+var ErrDummy = errors.New("dummy error")
+
+type MatchingAddressTestSuite struct {
+ suite.Suite
+ m mock.Mock
+
+ ips []net.IP
+ addrs []net.Addr
+}
+
+func (s *MatchingAddressTestSuite) InterfaceAddresses(name string) ([]net.Addr, error) {
+ args := s.m.Called(name)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ if ifaddrs, ok := args.Get(0).([]net.Addr); ok {
+ return ifaddrs, args.Error(1)
+ }
+ panic(fmt.Sprintf("assert: arguments: InterfaceAddresses(0) failed because object wasn't correct type: %v", args.Get(0)))
+}
+
+func (s *MatchingAddressTestSuite) Match(ip net.IP) bool {
+ args := s.m.Called(ip)
+ return args.Bool(0)
+}
+
+func (s *MatchingAddressTestSuite) SetupTest() {
+ InterfaceAddresses = s.InterfaceAddresses
+ s.ips = []net.IP{
+ net.ParseIP("2401:db00:3020:70e1:face:0:7e:0"),
+ net.ParseIP("2803:6080:890c:847e::1"),
+ net.ParseIP("fe80::4a57:ddff:fe04:d8e9"),
+ }
+ s.addrs = []net.Addr{}
+ for _, ip := range s.ips {
+ s.addrs = append(s.addrs, &net.IPNet{IP: ip})
+ }
+}
+
+func (s *MatchingAddressTestSuite) TestGetMatchingAddr() {
+ // Check if error from InterfaceAddresses immidately returns error
+ s.m.On("InterfaceAddresses", "eth0").Return(nil, ErrDummy).Once()
+ _, err := getMatchingAddr("eth0", s.Match)
+ s.Assert().Equal(ErrDummy, err)
+ s.m.AssertExpectations(s.T())
+ // Check if the looping is stopped after finding a matching address
+ s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once()
+ s.m.On("Match", s.ips[0]).Return(false).Once()
+ s.m.On("Match", s.ips[1]).Return(true).Once()
+ ip, err := getMatchingAddr("eth0", s.Match)
+ s.Require().NoError(err)
+ s.Assert().Equal(s.ips[1], ip)
+ s.m.AssertExpectations(s.T())
+ // Check if the looping skips not matching addresses
+ s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once()
+ s.m.On("Match", s.ips[0]).Return(false).Once()
+ s.m.On("Match", s.ips[1]).Return(false).Once()
+ s.m.On("Match", s.ips[2]).Return(true).Once()
+ ip, err = getMatchingAddr("eth0", s.Match)
+ s.Require().NoError(err)
+ s.Assert().Equal(s.ips[2], ip)
+ s.m.AssertExpectations(s.T())
+ // Check if the error is returned if no matching address is found
+ s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once()
+ s.m.On("Match", s.ips[0]).Return(false).Once()
+ s.m.On("Match", s.ips[1]).Return(false).Once()
+ s.m.On("Match", s.ips[2]).Return(false).Once()
+ _, err = getMatchingAddr("eth0", s.Match)
+ s.Assert().EqualError(err, "no matching address found for interface eth0")
+ s.m.AssertExpectations(s.T())
+}
+
+func (s *MatchingAddressTestSuite) TestGetLinkLocalAddr() {
+ s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once()
+ ip, err := GetLinkLocalAddr("eth0")
+ s.Require().NoError(err)
+ s.Assert().Equal(s.ips[2], ip)
+ s.m.AssertExpectations(s.T())
+}
+
+func (s *MatchingAddressTestSuite) TestGetGlobalAddr() {
+ s.m.On("InterfaceAddresses", "eth0").Return(s.addrs, nil).Once()
+ ip, err := GetGlobalAddr("eth0")
+ s.Require().NoError(err)
+ s.Assert().Equal(s.ips[0], ip)
+ s.m.AssertExpectations(s.T())
+}
+
+func TestMatchingAddressTestSuite(t *testing.T) {
+ suite.Run(t, new(MatchingAddressTestSuite))
+}