diff options
-rw-r--r-- | dhcpv6/dhcpv6relay.go | 36 | ||||
-rw-r--r-- | dhcpv6/dhcpv6relay_test.go | 25 |
2 files changed, 44 insertions, 17 deletions
diff --git a/dhcpv6/dhcpv6relay.go b/dhcpv6/dhcpv6relay.go index fb2a0c1..527eddc 100644 --- a/dhcpv6/dhcpv6relay.go +++ b/dhcpv6/dhcpv6relay.go @@ -158,28 +158,30 @@ func (d *DHCPv6Relay) GetInnerMessage() (DHCPv6, error) { } } -// GetInnerPeerAddr returns the peer address in the inner most relay info -// header, this is typically the IP address of the client making the request. -func (r *DHCPv6Relay) GetInnerPeerAddr() (net.IP, error) { - var ( - p DHCPv6 - err error - ) - p = r - hops := r.HopCount() - addr := r.PeerAddr() - for i := uint8(0); i < hops; i++ { - p, err = DecapsulateRelay(p) +// Recurse into a relay message and extract and return the inner DHCPv6Relay. +// Return nil if none found (e.g. not a relay message). +func (r *DHCPv6Relay) GetInnerRelay() (DHCPv6, error) { + p := r + for { + d, err := DecapsulateRelay(p) if err != nil { return nil, err } - if p.IsRelay() { - addr = p.(*DHCPv6Relay).PeerAddr() - } else { - return nil, fmt.Errorf("Wrong Hop count") + if !d.IsRelay() { + return p, nil } + p = d.(*DHCPv6Relay) + } +} + +// GetInnerPeerAddr returns the peer address in the inner most relay info +// header, this is typically the IP address of the client making the request. +func (r *DHCPv6Relay) GetInnerPeerAddr() (net.IP, error) { + p, err := r.GetInnerRelay() + if err != nil { + return nil, err } - return addr, nil + return p.(*DHCPv6Relay).PeerAddr(), nil } // NewRelayReplFromRelayForw creates a RELAY_REPL packet based on a RELAY_FORW diff --git a/dhcpv6/dhcpv6relay_test.go b/dhcpv6/dhcpv6relay_test.go index e08e6a4..3effb35 100644 --- a/dhcpv6/dhcpv6relay_test.go +++ b/dhcpv6/dhcpv6relay_test.go @@ -107,6 +107,31 @@ func TestDHCPv6RelayToBytes(t *testing.T) { } } +func TestGetInnerRelay(t *testing.T) { + m := DHCPv6Message{} + r1, err := EncapsulateRelay(&m, RELAY_FORW, net.IPv6linklocalallnodes, net.IPv6interfacelocalallnodes) + require.NoError(t, err) + r2, err := EncapsulateRelay(r1, RELAY_FORW, net.IPv6loopback, net.IPv6linklocalallnodes) + require.NoError(t, err) + r3, err := EncapsulateRelay(r2, RELAY_FORW, net.IPv6unspecified, net.IPv6linklocalallrouters) + require.NoError(t, err) + + relay3, ok := r3.(*DHCPv6Relay) + require.True(t, ok) + + ir, err := relay3.GetInnerRelay() + require.NoError(t, err) + relay, ok := ir.(*DHCPv6Relay) + require.True(t, ok) + require.Equal(t, relay.HopCount(), uint8(0)) + require.Equal(t, relay.LinkAddr(), net.IPv6linklocalallnodes) + require.Equal(t, relay.PeerAddr(), net.IPv6interfacelocalallnodes) + + innerPeerAddr, err := relay3.GetInnerPeerAddr() + require.NoError(t, err) + require.Equal(t, innerPeerAddr, net.IPv6interfacelocalallnodes) +} + func TestNewRelayRepFromRelayForw(t *testing.T) { rf := DHCPv6Relay{} rf.SetMessageType(RELAY_FORW) |