diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-05-14 16:29:33 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-05-14 16:32:16 -0700 |
commit | df2352796d1cbe5eea563d54380be60be18455bc (patch) | |
tree | ca9135a78ec0131bf2a517f708c218e5d9d58ade /pkg/tcpip | |
parent | 25f0ab3313c356fcfb9e4282eda3b2aa2278956d (diff) |
Control forwarding per NetworkEndpoint
...instead of per NetworkProtocol to better conform with linux
(https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt):
```
conf/interface/*
forwarding - BOOLEAN
Enable IP forwarding on this interface. This controls whether packets
received _on_ this interface can be forwarded.
```
Fixes #5932.
PiperOrigin-RevId: 373888000
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 76 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 82 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 145 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 248 |
12 files changed, 479 insertions, 205 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 049811cbb..23178277a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -63,9 +63,15 @@ const ( fragmentblockSize = 8 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -82,6 +88,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -151,14 +163,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { e.mu.Lock() defer e.mu.Unlock() + if !e.setForwarding(forwarding) { + return + } + if forwarding { // There does not seem to be an RFC requirement for a node to join the all // routers multicast address but @@ -852,7 +882,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) addressEndpoint.DecRef() pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.ip.InvalidDestinationAddressesReceived.Increment() return } @@ -1144,7 +1174,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1165,12 +1194,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - ids []uint32 hashIV uint32 @@ -1283,35 +1306,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 4051fda07..307e1972d 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -745,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - stack := e.protocol.stack - - // Is the networking stack operating as a router? - if !stack.Forwarding(ProtocolNumber) { - // ... No, silently drop the packet. + if !e.Forwarding() { received.routerOnlyPacketsDroppedByHost.Increment() return } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index f0e06f86b..95e11ac51 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -63,6 +63,11 @@ const ( buckets = 2048 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + // policyTable is the default policy table defined in RFC 6724 section 2.1. // // A more human-readable version: @@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 { var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -187,6 +193,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -405,20 +417,38 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t } } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.setForwarding(forwarding) { + return + } + allRoutersGroups := [...]tcpip.Address{ header.IPv6AllRoutersInterfaceLocalMulticastAddress, header.IPv6AllRoutersLinkLocalMulticastAddress, header.IPv6AllRoutersSiteLocalMulticastAddress, } - e.mu.Lock() - defer e.mu.Unlock() - if forwarding { // As per RFC 4291 section 2.8: // @@ -1109,7 +1139,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { addressEndpoint.DecRef() } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.InvalidDestinationAddressesReceived.Increment() return } @@ -1932,7 +1962,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1957,12 +1986,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - fragmentation *fragmentation.Fragmentation } @@ -2137,35 +2160,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index b29fed347..f0ff111c5 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -705,7 +705,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a protocol-wide configuration, so we check the // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding // packets. - if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) { + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment() return } @@ -1710,7 +1710,7 @@ func (ndp *ndpState) startSolicitingRouters() { return } - if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) { + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { return } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 570c6c00c..234e34952 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -732,15 +732,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { } func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - return s, ep - } + const nicID = 1 handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { var extHdrs header.IPv6ExtHdrSerializer @@ -865,6 +857,11 @@ func TestNDPValidation(t *testing.T) { }, } + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } + for _, typ := range types { for _, isRouter := range []bool{false, true} { name := typ.name @@ -875,7 +872,10 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) if isRouter { if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { @@ -883,6 +883,24 @@ func TestNDPValidation(t *testing.T) { } } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatal("cannot find network endpoint instance for IPv6") + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}) + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost @@ -907,12 +925,12 @@ func TestNDPValidation(t *testing.T) { // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + t.Errorf("got invalid.Value() = %d, want = 0", got) } - // RouterOnlyPacketsReceivedByHost count should initially be 0. + // Should initially not have dropped any packets. if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + t.Errorf("got routerOnly.Value() = %d, want = 0", got) } if t.Failed() { @@ -932,18 +950,18 @@ func TestNDPValidation(t *testing.T) { want = 1 } if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) + t.Errorf("got invalid.Value() = %d, want = %d", got, want) } want = 0 if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. + // Router only packets are expected to be dropped when operating + // as a host. want = 1 } if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + t.Errorf("got routerOnly.Value() = %d, want = %d", got, want) } - }) } }) diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index ff555722e..7107d598d 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -54,6 +54,11 @@ type fwdTestNetworkEndpoint struct { nic NetworkInterface proto *fwdTestNetworkProtocol dispatcher TransportDispatcher + + mu struct { + sync.RWMutex + forwarding bool + } } func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { @@ -169,11 +174,6 @@ type fwdTestNetworkProtocol struct { addrResolveDelay time.Duration onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -242,16 +242,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber return fwdTestNetNumber } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8d615500f..dbba2c79f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1000,3 +1000,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t return d.CheckDuplicateAddress(addr, h), nil } + +func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return &tcpip.ErrNotSupported{} + } + + forwardingEP.SetForwarding(enable) + return nil +} + +func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return false, &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return false, &tcpip.ErrNotSupported{} + } + + return forwardingEP.Forwarding(), nil +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index a82c807b4..85bb87b4b 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -658,9 +658,9 @@ type IPNetworkEndpointStats interface { IPStats() *tcpip.IPStats } -// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. -type ForwardingNetworkProtocol interface { - NetworkProtocol +// ForwardingNetworkEndpoint is a network endpoint that may forward packets. +type ForwardingNetworkEndpoint interface { + NetworkEndpoint // Forwarding returns the forwarding configuration. Forwarding() bool diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 8a044c073..f17c04277 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -446,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool { // If the source NIC and outgoing NIC are different, make sure the stack has // forwarding enabled, or the packet will be handled locally. - if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { + if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { return false } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 483a960c8..8814f45a6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -95,8 +95,9 @@ type Stack struct { } } - mu sync.RWMutex - nics map[tcpip.NICID]*nic + mu sync.RWMutex + nics map[tcpip.NICID]*nic + defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{} // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -348,22 +349,23 @@ func New(opts Options) *Stack { } s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*nic), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - nudConfigs: opts.NUDConfigs, - uniqueIDGenerator: opts.UniqueID, - nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), - secureRNG: opts.SecureRNG, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + nics: make(map[tcpip.NICID]*nic), + defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: opts.IPTables, + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + nudConfigs: opts.NUDConfigs, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + randomGenerator: mathrand.New(randSrc), + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -492,37 +494,61 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the -// passed protocol and sets the default setting for newly created NICs. -func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { - protocol, ok := s.networkProtocols[protocolNum] +// SetNICForwarding enables or disables packet forwarding on the specified NIC +// for the passed protocol. +func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrUnknownProtocol{} + return &tcpip.ErrUnknownNICID{} } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + return nic.setForwarding(protocol, enable) +} + +// NICForwarding returns the forwarding configuration for the specified NIC. +func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrNotSupported{} + return false, &tcpip.ErrUnknownNICID{} } - forwardingProtocol.SetForwarding(enable) - return nil + return nic.forwarding(protocol) } -// Forwarding returns true if packet forwarding between NICs is enabled for the -// passed protocol. -func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { - protocol, ok := s.networkProtocols[protocolNum] - if !ok { - return false +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + doneOnce := false + for id, nic := range s.nics { + if err := nic.setForwarding(protocol, enable); err != nil { + // Expect forwarding to be settable on all interfaces if it was set on + // one. + if doneOnce { + panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err)) + } + + return err + } + + doneOnce = true } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) - if !ok { - return false + if enable { + s.defaultForwardingEnabled[protocol] = struct{}{} + } else { + delete(s.defaultForwardingEnabled, protocol) } - return forwardingProtocol.Forwarding() + return nil } // PortRange returns the UDP and TCP inclusive range of ephemeral ports used in @@ -659,6 +685,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) + for proto := range s.defaultForwardingEnabled { + if err := n.setForwarding(proto, true); err != nil { + panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err)) + } + } s.nics[id] = n if !opts.Disabled { return n.enable() @@ -786,6 +817,10 @@ type NICInfo struct { // value sent in haType field of an ARP Request sent by this NIC and the // value expected in the haType field of an ARP response. ARPHardwareType header.ARPHardwareType + + // Forwarding holds the forwarding status for each network endpoint that + // supports forwarding. + Forwarding map[tcpip.NetworkProtocolNumber]bool } // HasNIC returns true if the NICID is defined in the stack. @@ -815,7 +850,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { netStats[proto] = netEP.Stats() } - nics[id] = NICInfo{ + info := NICInfo{ Name: nic.name, LinkAddress: nic.LinkEndpoint.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), @@ -825,7 +860,23 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), + Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), } + + for proto := range s.networkProtocols { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + info.Forwarding[proto] = forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } + } + + nics[id] = info } return nics } @@ -1029,6 +1080,20 @@ func (s *Stack) HandleLocal() bool { return s.handleLocal } +func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + return forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + return false + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -1081,7 +1146,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal + onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route @@ -1120,7 +1185,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n // requirement to do this from any RFC but simply a choice made to better // follow a strong host model which the netstack follows at the time of // writing. - if canForward && chosenRoute == (tcpip.Route{}) { + if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) { chosenRoute = route } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index ff88b1bd3..02d54d29b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct { mu struct { sync.RWMutex - enabled bool + enabled bool + forwarding bool } nic stack.NetworkInterface @@ -227,11 +228,6 @@ type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int defaultTTL uint8 - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -300,15 +296,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 42bc53328..92fa6257d 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -16,6 +16,7 @@ package forward_test import ( "bytes" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -34,6 +35,39 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +const ttl = 64 + +var ( + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") +) + +func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) +} + +func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) +} + +func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo))) +} + +func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest))) +} + func TestForwarding(t *testing.T) { const listenPort = 8080 @@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) { const ( nicID1 = 1 nicID2 = 2 - ttl = 64 ) var ( ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") - ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") - ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") ) - rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoRequest(e, src, dst, ttl) - } - - rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoRequest(e, src, dst, ttl) - } - - v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo))) - } - - v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest))) - } - tests := []struct { name string srcAddr, dstAddr tcpip.Address @@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv4EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) }, }, { @@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv4EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) }, }, @@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv6EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) }, }, { @@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv6EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) }, }, } @@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) { }) } } + +func TestPerInterfaceForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + ) + + tests := []struct { + name string + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 unicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 multicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4GlobalMulticastAddr, + rx: rxICMPv4EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + }, + }, + + { + name: "IPv6 unicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 multicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6GlobalMulticastAddr, + rx: rxICMPv6EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + }, + }, + } + + netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + // ARP is not used in this test but it is a network protocol that does + // not support forwarding. We install the protocol to make sure that + // forwarding information for a NIC is only reported for network + // protocols that support forwarding. + arp.NewProtocol, + + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) + } + + for _, add := range [...]struct { + nicID tcpip.NICID + addr tcpip.ProtocolAddress + }{ + { + nicID: nicID1, + addr: utils.RouterNIC1IPv4Addr, + }, + { + nicID: nicID1, + addr: utils.RouterNIC1IPv6Addr, + }, + { + nicID: nicID2, + addr: utils.RouterNIC2IPv4Addr, + }, + { + nicID: nicID2, + addr: utils.RouterNIC2IPv6Addr, + }, + } { + if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) + } + } + + // Only enable forwarding on NIC1 and make sure that only packets arriving + // on NIC1 are forwarded. + for _, netProto := range netProtos { + if err := s.SetNICForwarding(nicID1, netProto, true); err != nil { + t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err) + } + } + + nicsInfo := s.NICInfo() + for _, subTest := range [...]struct { + nicID tcpip.NICID + nicEP *channel.Endpoint + otherNICID tcpip.NICID + otherNICEP *channel.Endpoint + expectForwarding bool + }{ + { + nicID: nicID1, + nicEP: e1, + otherNICID: nicID2, + otherNICEP: e2, + expectForwarding: true, + }, + { + nicID: nicID2, + nicEP: e2, + otherNICID: nicID2, + otherNICEP: e1, + expectForwarding: false, + }, + } { + t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) { + nicInfo, ok := nicsInfo[subTest.nicID] + if !ok { + t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo) + } else { + forwarding := make(map[tcpip.NetworkProtocolNumber]bool) + for _, netProto := range netProtos { + forwarding[netProto] = subTest.expectForwarding + } + + if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" { + t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: subTest.otherNICID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: subTest.otherNICID, + }, + }) + + test.rx(subTest.nicEP, test.srcAddr, test.dstAddr) + if p, ok := subTest.nicEP.Read(); ok { + t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p) + } + if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding { + t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding) + } else if subTest.expectForwarding { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } + }) + } +} |