diff options
-rw-r--r-- | dhcpv6/dhcpv6relay.go | 51 | ||||
-rw-r--r-- | dhcpv6/dhcpv6relay_test.go | 34 |
2 files changed, 85 insertions, 0 deletions
diff --git a/dhcpv6/dhcpv6relay.go b/dhcpv6/dhcpv6relay.go index 68f3cf0..fb2a0c1 100644 --- a/dhcpv6/dhcpv6relay.go +++ b/dhcpv6/dhcpv6relay.go @@ -1,6 +1,7 @@ package dhcpv6 import ( + "errors" "fmt" "net" ) @@ -180,3 +181,53 @@ func (r *DHCPv6Relay) GetInnerPeerAddr() (net.IP, error) { } return addr, nil } + +// NewRelayReplFromRelayForw creates a RELAY_REPL packet based on a RELAY_FORW +// packet and replaces the inner message with the passed DHCPv6 message. +func NewRelayReplFromRelayForw(relayForw, msg DHCPv6) (DHCPv6, error) { + var ( + err error + linkAddr, peerAddr []net.IP + optiids []Option + ) + if relayForw == nil { + return nil, errors.New("RELAY_FORW cannot be nil") + } + relay, ok := relayForw.(*DHCPv6Relay) + if !ok { + return nil, errors.New("Not a DHCPv6Relay") + } + if relay.Type() != RELAY_FORW { + return nil, errors.New("The passed packet is not of type RELAY_FORW") + } + if msg == nil { + return nil, errors.New("The passed message cannot be nil") + } + if msg.IsRelay() { + return nil, errors.New("The passed message cannot be a relay") + } + for { + linkAddr = append(linkAddr, relay.LinkAddr()) + peerAddr = append(peerAddr, relay.PeerAddr()) + optiids = append(optiids, relay.GetOneOption(OPTION_INTERFACE_ID)) + decap, err := DecapsulateRelay(relay) + if err != nil { + return nil, err + } + if decap.IsRelay() { + relay = decap.(*DHCPv6Relay) + } else { + break + } + } + for i := len(linkAddr) - 1; i >= 0; i-- { + msg, err = EncapsulateRelay(msg, RELAY_REPL, linkAddr[i], peerAddr[i]) + if opt := optiids[i]; opt != nil { + msg.AddOption(opt) + } + if err != nil { + return nil, err + } + } + return msg, nil +} diff --git a/dhcpv6/dhcpv6relay_test.go b/dhcpv6/dhcpv6relay_test.go index c5989e5..e08e6a4 100644 --- a/dhcpv6/dhcpv6relay_test.go +++ b/dhcpv6/dhcpv6relay_test.go @@ -4,6 +4,8 @@ import ( "bytes" "net" "testing" + + "github.com/stretchr/testify/require" ) func TestDHCPv6Relay(t *testing.T) { @@ -104,3 +106,35 @@ func TestDHCPv6RelayToBytes(t *testing.T) { t.Fatalf("Invalid ToBytes result. Expected %v, got %v", expected, relayBytes) } } + +func TestNewRelayRepFromRelayForw(t *testing.T) { + rf := DHCPv6Relay{} + rf.SetMessageType(RELAY_FORW) + rf.SetPeerAddr(net.IPv6linklocalallrouters) + rf.SetLinkAddr(net.IPv6interfacelocalallnodes) + oro := OptRelayMsg{} + s := DHCPv6Message{} + s.SetMessage(SOLICIT) + cid := OptClientId{} + s.AddOption(&cid) + oro.SetRelayMessage(&s) + rf.AddOption(&oro) + + a, err := NewAdvertiseFromSolicit(&s) + require.NoError(t, err) + rr, err := NewRelayReplFromRelayForw(&rf, a) + require.NoError(t, err) + relay := rr.(*DHCPv6Relay) + require.Equal(t, rr.Type(), RELAY_REPL) + require.Equal(t, relay.HopCount(), rf.HopCount()) + require.Equal(t, relay.PeerAddr(), rf.PeerAddr()) + require.Equal(t, relay.LinkAddr(), rf.LinkAddr()) + m, err := relay.GetInnerMessage() + require.NoError(t, err) + require.Equal(t, m, a) + + rr, err = NewRelayReplFromRelayForw(nil, a) + require.Error(t, err) + rr, err = NewRelayReplFromRelayForw(&rf, nil) + require.Error(t, err) +} |