summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--dhcpv6/dhcpv6relay.go51
-rw-r--r--dhcpv6/dhcpv6relay_test.go34
2 files changed, 85 insertions, 0 deletions
diff --git a/dhcpv6/dhcpv6relay.go b/dhcpv6/dhcpv6relay.go
index 68f3cf0..fb2a0c1 100644
--- a/dhcpv6/dhcpv6relay.go
+++ b/dhcpv6/dhcpv6relay.go
@@ -1,6 +1,7 @@
package dhcpv6
import (
+ "errors"
"fmt"
"net"
)
@@ -180,3 +181,53 @@ func (r *DHCPv6Relay) GetInnerPeerAddr() (net.IP, error) {
}
return addr, nil
}
+
+// NewRelayReplFromRelayForw creates a RELAY_REPL packet based on a RELAY_FORW
+// packet and replaces the inner message with the passed DHCPv6 message.
+func NewRelayReplFromRelayForw(relayForw, msg DHCPv6) (DHCPv6, error) {
+ var (
+ err error
+ linkAddr, peerAddr []net.IP
+ optiids []Option
+ )
+ if relayForw == nil {
+ return nil, errors.New("RELAY_FORW cannot be nil")
+ }
+ relay, ok := relayForw.(*DHCPv6Relay)
+ if !ok {
+ return nil, errors.New("Not a DHCPv6Relay")
+ }
+ if relay.Type() != RELAY_FORW {
+ return nil, errors.New("The passed packet is not of type RELAY_FORW")
+ }
+ if msg == nil {
+ return nil, errors.New("The passed message cannot be nil")
+ }
+ if msg.IsRelay() {
+ return nil, errors.New("The passed message cannot be a relay")
+ }
+ for {
+ linkAddr = append(linkAddr, relay.LinkAddr())
+ peerAddr = append(peerAddr, relay.PeerAddr())
+ optiids = append(optiids, relay.GetOneOption(OPTION_INTERFACE_ID))
+ decap, err := DecapsulateRelay(relay)
+ if err != nil {
+ return nil, err
+ }
+ if decap.IsRelay() {
+ relay = decap.(*DHCPv6Relay)
+ } else {
+ break
+ }
+ }
+ for i := len(linkAddr) - 1; i >= 0; i-- {
+ msg, err = EncapsulateRelay(msg, RELAY_REPL, linkAddr[i], peerAddr[i])
+ if opt := optiids[i]; opt != nil {
+ msg.AddOption(opt)
+ }
+ if err != nil {
+ return nil, err
+ }
+ }
+ return msg, nil
+}
diff --git a/dhcpv6/dhcpv6relay_test.go b/dhcpv6/dhcpv6relay_test.go
index c5989e5..e08e6a4 100644
--- a/dhcpv6/dhcpv6relay_test.go
+++ b/dhcpv6/dhcpv6relay_test.go
@@ -4,6 +4,8 @@ import (
"bytes"
"net"
"testing"
+
+ "github.com/stretchr/testify/require"
)
func TestDHCPv6Relay(t *testing.T) {
@@ -104,3 +106,35 @@ func TestDHCPv6RelayToBytes(t *testing.T) {
t.Fatalf("Invalid ToBytes result. Expected %v, got %v", expected, relayBytes)
}
}
+
+func TestNewRelayRepFromRelayForw(t *testing.T) {
+ rf := DHCPv6Relay{}
+ rf.SetMessageType(RELAY_FORW)
+ rf.SetPeerAddr(net.IPv6linklocalallrouters)
+ rf.SetLinkAddr(net.IPv6interfacelocalallnodes)
+ oro := OptRelayMsg{}
+ s := DHCPv6Message{}
+ s.SetMessage(SOLICIT)
+ cid := OptClientId{}
+ s.AddOption(&cid)
+ oro.SetRelayMessage(&s)
+ rf.AddOption(&oro)
+
+ a, err := NewAdvertiseFromSolicit(&s)
+ require.NoError(t, err)
+ rr, err := NewRelayReplFromRelayForw(&rf, a)
+ require.NoError(t, err)
+ relay := rr.(*DHCPv6Relay)
+ require.Equal(t, rr.Type(), RELAY_REPL)
+ require.Equal(t, relay.HopCount(), rf.HopCount())
+ require.Equal(t, relay.PeerAddr(), rf.PeerAddr())
+ require.Equal(t, relay.LinkAddr(), rf.LinkAddr())
+ m, err := relay.GetInnerMessage()
+ require.NoError(t, err)
+ require.Equal(t, m, a)
+
+ rr, err = NewRelayReplFromRelayForw(nil, a)
+ require.Error(t, err)
+ rr, err = NewRelayReplFromRelayForw(&rf, nil)
+ require.Error(t, err)
+}