summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv6/dhcpv6relay.go36
-rw-r--r--dhcpv6/dhcpv6relay_test.go25
2 files changed, 44 insertions, 17 deletions
diff --git a/dhcpv6/dhcpv6relay.go b/dhcpv6/dhcpv6relay.go
index fb2a0c1..527eddc 100644
--- a/dhcpv6/dhcpv6relay.go
+++ b/dhcpv6/dhcpv6relay.go
@@ -158,28 +158,30 @@ func (d *DHCPv6Relay) GetInnerMessage() (DHCPv6, error) {
}
}
-// GetInnerPeerAddr returns the peer address in the inner most relay info
-// header, this is typically the IP address of the client making the request.
-func (r *DHCPv6Relay) GetInnerPeerAddr() (net.IP, error) {
- var (
- p DHCPv6
- err error
- )
- p = r
- hops := r.HopCount()
- addr := r.PeerAddr()
- for i := uint8(0); i < hops; i++ {
- p, err = DecapsulateRelay(p)
+// Recurse into a relay message and extract and return the inner DHCPv6Relay.
+// Return nil if none found (e.g. not a relay message).
+func (r *DHCPv6Relay) GetInnerRelay() (DHCPv6, error) {
+ p := r
+ for {
+ d, err := DecapsulateRelay(p)
if err != nil {
return nil, err
}
- if p.IsRelay() {
- addr = p.(*DHCPv6Relay).PeerAddr()
- } else {
- return nil, fmt.Errorf("Wrong Hop count")
+ if !d.IsRelay() {
+ return p, nil
}
+ p = d.(*DHCPv6Relay)
+ }
+}
+
+// GetInnerPeerAddr returns the peer address in the inner most relay info
+// header, this is typically the IP address of the client making the request.
+func (r *DHCPv6Relay) GetInnerPeerAddr() (net.IP, error) {
+ p, err := r.GetInnerRelay()
+ if err != nil {
+ return nil, err
}
- return addr, nil
+ return p.(*DHCPv6Relay).PeerAddr(), nil
}
// NewRelayReplFromRelayForw creates a RELAY_REPL packet based on a RELAY_FORW
diff --git a/dhcpv6/dhcpv6relay_test.go b/dhcpv6/dhcpv6relay_test.go
index e08e6a4..3effb35 100644
--- a/dhcpv6/dhcpv6relay_test.go
+++ b/dhcpv6/dhcpv6relay_test.go
@@ -107,6 +107,31 @@ func TestDHCPv6RelayToBytes(t *testing.T) {
}
}
+func TestGetInnerRelay(t *testing.T) {
+ m := DHCPv6Message{}
+ r1, err := EncapsulateRelay(&m, RELAY_FORW, net.IPv6linklocalallnodes, net.IPv6interfacelocalallnodes)
+ require.NoError(t, err)
+ r2, err := EncapsulateRelay(r1, RELAY_FORW, net.IPv6loopback, net.IPv6linklocalallnodes)
+ require.NoError(t, err)
+ r3, err := EncapsulateRelay(r2, RELAY_FORW, net.IPv6unspecified, net.IPv6linklocalallrouters)
+ require.NoError(t, err)
+
+ relay3, ok := r3.(*DHCPv6Relay)
+ require.True(t, ok)
+
+ ir, err := relay3.GetInnerRelay()
+ require.NoError(t, err)
+ relay, ok := ir.(*DHCPv6Relay)
+ require.True(t, ok)
+ require.Equal(t, relay.HopCount(), uint8(0))
+ require.Equal(t, relay.LinkAddr(), net.IPv6linklocalallnodes)
+ require.Equal(t, relay.PeerAddr(), net.IPv6interfacelocalallnodes)
+
+ innerPeerAddr, err := relay3.GetInnerPeerAddr()
+ require.NoError(t, err)
+ require.Equal(t, innerPeerAddr, net.IPv6interfacelocalallnodes)
+}
+
func TestNewRelayRepFromRelayForw(t *testing.T) {
rf := DHCPv6Relay{}
rf.SetMessageType(RELAY_FORW)