summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-05-14 16:29:33 -0700
committergVisor bot <gvisor-bot@google.com>2021-05-14 16:32:16 -0700
commitdf2352796d1cbe5eea563d54380be60be18455bc (patch)
treeca9135a78ec0131bf2a517f708c218e5d9d58ade /pkg/tcpip/network
parent25f0ab3313c356fcfb9e4282eda3b2aa2278956d (diff)
Control forwarding per NetworkEndpoint
...instead of per NetworkProtocol to better conform with linux (https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt): ``` conf/interface/* forwarding - BOOLEAN Enable IP forwarding on this interface. This controls whether packets received _on_ this interface can be forwarded. ``` Fixes #5932. PiperOrigin-RevId: 373888000
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go76
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go82
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go52
5 files changed, 111 insertions, 109 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 049811cbb..23178277a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -63,9 +63,15 @@ const (
fragmentblockSize = 8
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -82,6 +88,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -151,14 +163,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
e.mu.Lock()
defer e.mu.Unlock()
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
if forwarding {
// There does not seem to be an RFC requirement for a node to join the all
// routers multicast address but
@@ -852,7 +882,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
addressEndpoint.DecRef()
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -1144,7 +1174,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1165,12 +1194,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
ids []uint32
hashIV uint32
@@ -1283,35 +1306,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 4051fda07..307e1972d 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -745,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- stack := e.protocol.stack
-
- // Is the networking stack operating as a router?
- if !stack.Forwarding(ProtocolNumber) {
- // ... No, silently drop the packet.
+ if !e.Forwarding() {
received.routerOnlyPacketsDroppedByHost.Increment()
return
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f0e06f86b..95e11ac51 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -63,6 +63,11 @@ const (
buckets = 2048
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
// policyTable is the default policy table defined in RFC 6724 section 2.1.
//
// A more human-readable version:
@@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 {
var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -187,6 +193,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -405,20 +417,38 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
}
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
allRoutersGroups := [...]tcpip.Address{
header.IPv6AllRoutersInterfaceLocalMulticastAddress,
header.IPv6AllRoutersLinkLocalMulticastAddress,
header.IPv6AllRoutersSiteLocalMulticastAddress,
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
if forwarding {
// As per RFC 4291 section 2.8:
//
@@ -1109,7 +1139,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -1932,7 +1962,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1957,12 +1986,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
fragmentation *fragmentation.Fragmentation
}
@@ -2137,35 +2160,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index b29fed347..f0ff111c5 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -705,7 +705,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a protocol-wide configuration, so we check the
// protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
// packets.
- if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment()
return
}
@@ -1710,7 +1710,7 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
- if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
return
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 570c6c00c..234e34952 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -732,15 +732,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- return s, ep
- }
+ const nicID = 1
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
var extHdrs header.IPv6ExtHdrSerializer
@@ -865,6 +857,11 @@ func TestNDPValidation(t *testing.T) {
},
}
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
+
for _, typ := range types {
for _, isRouter := range []bool{false, true} {
name := typ.name
@@ -875,7 +872,10 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep := setup(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
if isRouter {
if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
@@ -883,6 +883,24 @@ func TestNDPValidation(t *testing.T) {
}
}
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatal("cannot find network endpoint instance for IPv6")
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }})
+
stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
@@ -907,12 +925,12 @@ func TestNDPValidation(t *testing.T) {
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ t.Errorf("got invalid.Value() = %d, want = 0", got)
}
- // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ // Should initially not have dropped any packets.
if got := routerOnly.Value(); got != 0 {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ t.Errorf("got routerOnly.Value() = %d, want = 0", got)
}
if t.Failed() {
@@ -932,18 +950,18 @@ func TestNDPValidation(t *testing.T) {
want = 1
}
if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ t.Errorf("got invalid.Value() = %d, want = %d", got, want)
}
want = 0
if test.valid && !isRouter && typ.routerOnly {
- // RouterOnlyPacketsReceivedByHost count should have increased.
+ // Router only packets are expected to be dropped when operating
+ // as a host.
want = 1
}
if got := routerOnly.Value(); got != want {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ t.Errorf("got routerOnly.Value() = %d, want = %d", got, want)
}
-
})
}
})