summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-05-14 16:29:33 -0700
committergVisor bot <gvisor-bot@google.com>2021-05-14 16:32:16 -0700
commitdf2352796d1cbe5eea563d54380be60be18455bc (patch)
treeca9135a78ec0131bf2a517f708c218e5d9d58ade
parent25f0ab3313c356fcfb9e4282eda3b2aa2278956d (diff)
Control forwarding per NetworkEndpoint
...instead of per NetworkProtocol to better conform with linux (https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt): ``` conf/interface/* forwarding - BOOLEAN Enable IP forwarding on this interface. This controls whether packets received _on_ this interface can be forwarded. ``` Fixes #5932. PiperOrigin-RevId: 373888000
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go76
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go82
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go52
-rw-r--r--pkg/tcpip/stack/forwarding_test.go18
-rw-r--r--pkg/tcpip/stack/nic.go29
-rw-r--r--pkg/tcpip/stack/registration.go6
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/stack/stack.go145
-rw-r--r--pkg/tcpip/stack/stack_test.go16
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go248
12 files changed, 479 insertions, 205 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 049811cbb..23178277a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -63,9 +63,15 @@ const (
fragmentblockSize = 8
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -82,6 +88,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -151,14 +163,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
e.mu.Lock()
defer e.mu.Unlock()
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
if forwarding {
// There does not seem to be an RFC requirement for a node to join the all
// routers multicast address but
@@ -852,7 +882,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
addressEndpoint.DecRef()
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -1144,7 +1174,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1165,12 +1194,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
ids []uint32
hashIV uint32
@@ -1283,35 +1306,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 4051fda07..307e1972d 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -745,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- stack := e.protocol.stack
-
- // Is the networking stack operating as a router?
- if !stack.Forwarding(ProtocolNumber) {
- // ... No, silently drop the packet.
+ if !e.Forwarding() {
received.routerOnlyPacketsDroppedByHost.Increment()
return
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f0e06f86b..95e11ac51 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -63,6 +63,11 @@ const (
buckets = 2048
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
// policyTable is the default policy table defined in RFC 6724 section 2.1.
//
// A more human-readable version:
@@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 {
var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -187,6 +193,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -405,20 +417,38 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
}
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
allRoutersGroups := [...]tcpip.Address{
header.IPv6AllRoutersInterfaceLocalMulticastAddress,
header.IPv6AllRoutersLinkLocalMulticastAddress,
header.IPv6AllRoutersSiteLocalMulticastAddress,
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
if forwarding {
// As per RFC 4291 section 2.8:
//
@@ -1109,7 +1139,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -1932,7 +1962,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1957,12 +1986,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
fragmentation *fragmentation.Fragmentation
}
@@ -2137,35 +2160,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index b29fed347..f0ff111c5 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -705,7 +705,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a protocol-wide configuration, so we check the
// protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
// packets.
- if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment()
return
}
@@ -1710,7 +1710,7 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
- if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
return
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 570c6c00c..234e34952 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -732,15 +732,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- return s, ep
- }
+ const nicID = 1
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
var extHdrs header.IPv6ExtHdrSerializer
@@ -865,6 +857,11 @@ func TestNDPValidation(t *testing.T) {
},
}
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
+
for _, typ := range types {
for _, isRouter := range []bool{false, true} {
name := typ.name
@@ -875,7 +872,10 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep := setup(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
if isRouter {
if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
@@ -883,6 +883,24 @@ func TestNDPValidation(t *testing.T) {
}
}
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatal("cannot find network endpoint instance for IPv6")
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }})
+
stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
@@ -907,12 +925,12 @@ func TestNDPValidation(t *testing.T) {
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ t.Errorf("got invalid.Value() = %d, want = 0", got)
}
- // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ // Should initially not have dropped any packets.
if got := routerOnly.Value(); got != 0 {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ t.Errorf("got routerOnly.Value() = %d, want = 0", got)
}
if t.Failed() {
@@ -932,18 +950,18 @@ func TestNDPValidation(t *testing.T) {
want = 1
}
if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ t.Errorf("got invalid.Value() = %d, want = %d", got, want)
}
want = 0
if test.valid && !isRouter && typ.routerOnly {
- // RouterOnlyPacketsReceivedByHost count should have increased.
+ // Router only packets are expected to be dropped when operating
+ // as a host.
want = 1
}
if got := routerOnly.Value(); got != want {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ t.Errorf("got routerOnly.Value() = %d, want = %d", got, want)
}
-
})
}
})
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index ff555722e..7107d598d 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -54,6 +54,11 @@ type fwdTestNetworkEndpoint struct {
nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
func (*fwdTestNetworkEndpoint) Enable() tcpip.Error {
@@ -169,11 +174,6 @@ type fwdTestNetworkProtocol struct {
addrResolveDelay time.Duration
onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
}
func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -242,16 +242,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber
return fwdTestNetNumber
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (f *fwdTestNetworkProtocol) Forwarding() bool {
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fwdTestNetworkEndpoint) Forwarding() bool {
f.mu.RLock()
defer f.mu.RUnlock()
return f.mu.forwarding
}
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.forwarding = v
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8d615500f..dbba2c79f 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1000,3 +1000,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t
return d.CheckDuplicateAddress(addr, h), nil
}
+
+func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ forwardingEP.SetForwarding(enable)
+ return nil
+}
+
+func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return false, &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return false, &tcpip.ErrNotSupported{}
+ }
+
+ return forwardingEP.Forwarding(), nil
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index a82c807b4..85bb87b4b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -658,9 +658,9 @@ type IPNetworkEndpointStats interface {
IPStats() *tcpip.IPStats
}
-// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
-type ForwardingNetworkProtocol interface {
- NetworkProtocol
+// ForwardingNetworkEndpoint is a network endpoint that may forward packets.
+type ForwardingNetworkEndpoint interface {
+ NetworkEndpoint
// Forwarding returns the forwarding configuration.
Forwarding() bool
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 8a044c073..f17c04277 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -446,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool {
// If the source NIC and outgoing NIC are different, make sure the stack has
// forwarding enabled, or the packet will be handled locally.
- if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
+ if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
return false
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 483a960c8..8814f45a6 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -95,8 +95,9 @@ type Stack struct {
}
}
- mu sync.RWMutex
- nics map[tcpip.NICID]*nic
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*nic
+ defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{}
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -348,22 +349,23 @@ func New(opts Options) *Stack {
}
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: mathrand.New(randSrc),
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: mathrand.New(randSrc),
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -492,37 +494,61 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
-// passed protocol and sets the default setting for newly created NICs.
-func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
- protocol, ok := s.networkProtocols[protocolNum]
+// SetNICForwarding enables or disables packet forwarding on the specified NIC
+// for the passed protocol.
+func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrUnknownProtocol{}
+ return &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ return nic.setForwarding(protocol, enable)
+}
+
+// NICForwarding returns the forwarding configuration for the specified NIC.
+func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrNotSupported{}
+ return false, &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol.SetForwarding(enable)
- return nil
+ return nic.forwarding(protocol)
}
-// Forwarding returns true if packet forwarding between NICs is enabled for the
-// passed protocol.
-func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
- protocol, ok := s.networkProtocols[protocolNum]
- if !ok {
- return false
+// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
+// passed protocol and sets the default setting for newly created NICs.
+func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ doneOnce := false
+ for id, nic := range s.nics {
+ if err := nic.setForwarding(protocol, enable); err != nil {
+ // Expect forwarding to be settable on all interfaces if it was set on
+ // one.
+ if doneOnce {
+ panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err))
+ }
+
+ return err
+ }
+
+ doneOnce = true
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
- if !ok {
- return false
+ if enable {
+ s.defaultForwardingEnabled[protocol] = struct{}{}
+ } else {
+ delete(s.defaultForwardingEnabled, protocol)
}
- return forwardingProtocol.Forwarding()
+ return nil
}
// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
@@ -659,6 +685,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
+ for proto := range s.defaultForwardingEnabled {
+ if err := n.setForwarding(proto, true); err != nil {
+ panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err))
+ }
+ }
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -786,6 +817,10 @@ type NICInfo struct {
// value sent in haType field of an ARP Request sent by this NIC and the
// value expected in the haType field of an ARP response.
ARPHardwareType header.ARPHardwareType
+
+ // Forwarding holds the forwarding status for each network endpoint that
+ // supports forwarding.
+ Forwarding map[tcpip.NetworkProtocolNumber]bool
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -815,7 +850,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
netStats[proto] = netEP.Stats()
}
- nics[id] = NICInfo{
+ info := NICInfo{
Name: nic.name,
LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
@@ -825,7 +860,23 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
+ Forwarding: make(map[tcpip.NetworkProtocolNumber]bool),
}
+
+ for proto := range s.networkProtocols {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ info.Forwarding[proto] = forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+ }
+
+ nics[id] = info
}
return nics
}
@@ -1029,6 +1080,20 @@ func (s *Stack) HandleLocal() bool {
return s.handleLocal
}
+func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ return forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ return false
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given NIC and local address (if provided).
//
@@ -1081,7 +1146,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
+ onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
@@ -1120,7 +1185,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
// requirement to do this from any RFC but simply a choice made to better
// follow a strong host model which the netstack follows at the time of
// writing.
- if canForward && chosenRoute == (tcpip.Route{}) {
+ if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) {
chosenRoute = route
}
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index ff88b1bd3..02d54d29b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct {
mu struct {
sync.RWMutex
- enabled bool
+ enabled bool
+ forwarding bool
}
nic stack.NetworkInterface
@@ -227,11 +228,6 @@ type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
defaultTTL uint8
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
}
func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -300,15 +296,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (f *fakeNetworkProtocol) Forwarding() bool {
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fakeNetworkEndpoint) Forwarding() bool {
f.mu.RLock()
defer f.mu.RUnlock()
return f.mu.forwarding
}
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (f *fakeNetworkProtocol) SetForwarding(v bool) {
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fakeNetworkEndpoint) SetForwarding(v bool) {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.forwarding = v
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 42bc53328..92fa6257d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -16,6 +16,7 @@ package forward_test
import (
"bytes"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -34,6 +35,39 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+const ttl = 64
+
+var (
+ ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
+ ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
+)
+
+func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+}
+
+func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+}
+
+func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo)))
+}
+
+func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest)))
+}
+
func TestForwarding(t *testing.T) {
const listenPort = 8080
@@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) {
const (
nicID1 = 1
nicID2 = 2
- ttl = 64
)
var (
ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
- ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
- ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
)
- rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoRequest(e, src, dst, ttl)
- }
-
- rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoRequest(e, src, dst, ttl)
- }
-
- v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo)))
- }
-
- v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoRequest)))
- }
-
tests := []struct {
name string
srcAddr, dstAddr tcpip.Address
@@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv4EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
},
},
{
@@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv4EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
},
},
@@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv6EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
},
},
{
@@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) {
rx: rxICMPv6EchoRequest,
expectForward: true,
checker: func(t *testing.T, b []byte) {
- v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
},
},
}
@@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) {
})
}
}
+
+func TestPerInterfaceForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ )
+
+ tests := []struct {
+ name string
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 unicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4GlobalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ },
+ },
+
+ {
+ name: "IPv6 unicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6GlobalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ checker: func(t *testing.T, b []byte) {
+ forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ },
+ },
+ }
+
+ netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ // ARP is not used in this test but it is a network protocol that does
+ // not support forwarding. We install the protocol to make sure that
+ // forwarding information for a NIC is only reported for network
+ // protocols that support forwarding.
+ arp.NewProtocol,
+
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ for _, add := range [...]struct {
+ nicID tcpip.NICID
+ addr tcpip.ProtocolAddress
+ }{
+ {
+ nicID: nicID1,
+ addr: utils.RouterNIC1IPv4Addr,
+ },
+ {
+ nicID: nicID1,
+ addr: utils.RouterNIC1IPv6Addr,
+ },
+ {
+ nicID: nicID2,
+ addr: utils.RouterNIC2IPv4Addr,
+ },
+ {
+ nicID: nicID2,
+ addr: utils.RouterNIC2IPv6Addr,
+ },
+ } {
+ if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ }
+ }
+
+ // Only enable forwarding on NIC1 and make sure that only packets arriving
+ // on NIC1 are forwarded.
+ for _, netProto := range netProtos {
+ if err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
+ t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err)
+ }
+ }
+
+ nicsInfo := s.NICInfo()
+ for _, subTest := range [...]struct {
+ nicID tcpip.NICID
+ nicEP *channel.Endpoint
+ otherNICID tcpip.NICID
+ otherNICEP *channel.Endpoint
+ expectForwarding bool
+ }{
+ {
+ nicID: nicID1,
+ nicEP: e1,
+ otherNICID: nicID2,
+ otherNICEP: e2,
+ expectForwarding: true,
+ },
+ {
+ nicID: nicID2,
+ nicEP: e2,
+ otherNICID: nicID2,
+ otherNICEP: e1,
+ expectForwarding: false,
+ },
+ } {
+ t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
+ nicInfo, ok := nicsInfo[subTest.nicID]
+ if !ok {
+ t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
+ } else {
+ forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
+ for _, netProto := range netProtos {
+ forwarding[netProto] = subTest.expectForwarding
+ }
+
+ if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
+ t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: subTest.otherNICID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: subTest.otherNICID,
+ },
+ })
+
+ test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
+ if p, ok := subTest.nicEP.Read(); ok {
+ t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
+ }
+ if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding {
+ t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
+ } else if subTest.expectForwarding {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+ })
+ }
+}