summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/forwarder_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/forwarder_test.go')
-rw-r--r--pkg/tcpip/stack/forwarder_test.go45
1 files changed, 41 insertions, 4 deletions
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 8d18f3c8c..572a2c3b6 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -45,6 +46,8 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fwdTestNetworkEndpoint struct {
+ AddressableEndpointState
+
nicID tcpip.NICID
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
@@ -53,6 +56,16 @@ type fwdTestNetworkEndpoint struct {
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
+func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error {
+ return nil
+}
+
+func (*fwdTestNetworkEndpoint) Enabled() bool {
+ return true
+}
+
+func (*fwdTestNetworkEndpoint) Disable() {}
+
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
@@ -106,7 +119,9 @@ func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBu
return tcpip.ErrNotSupported
}
-func (*fwdTestNetworkEndpoint) Close() {}
+func (f *fwdTestNetworkEndpoint) Close() {
+ f.AddressableEndpointState.Cleanup()
+}
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
@@ -116,6 +131,11 @@ type fwdTestNetworkProtocol struct {
addrResolveDelay time.Duration
onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
@@ -145,13 +165,15 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
-func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
- return &fwdTestNetworkEndpoint{
- nicID: nicID,
+func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
+ e := &fwdTestNetworkEndpoint{
+ nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
ep: ep,
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
@@ -186,6 +208,21 @@ func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber
return fwdTestNetNumber
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (f *fwdTestNetworkProtocol) Forwarding() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.forwarding
+
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.forwarding = v
+}
+
// fwdTestPacketInfo holds all the information about an outbound packet.
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress