summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/stack.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/stack.go')
-rw-r--r--pkg/tcpip/stack/stack.go75
1 files changed, 62 insertions, 13 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 242d2150c..71e0618f4 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -21,7 +21,9 @@ package stack
import (
"encoding/binary"
+ "math"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/time/rate"
@@ -48,6 +50,42 @@ const (
DefaultTOS = 0
)
+const (
+ // fakeNetNumber is used as a protocol number in tests.
+ //
+ // This constant should match fakeNetNumber in stack_test.go.
+ fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+)
+
+type forwardingFlag uint32
+
+// Packet forwarding flags. Forwarding settings for different network protocols
+// are stored as bit flags in an uint32 number.
+const (
+ forwardingIPv4 forwardingFlag = 1 << iota
+ forwardingIPv6
+
+ // forwardingFake is used to test package forwarding with a fake protocol.
+ forwardingFake
+)
+
+func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag {
+ var flag forwardingFlag
+ switch protocol {
+ case header.IPv4ProtocolNumber:
+ flag = forwardingIPv4
+ case header.IPv6ProtocolNumber:
+ flag = forwardingIPv6
+ case fakeNetNumber:
+ // This network protocol number is used in stack_test to test
+ // packet forwarding.
+ flag = forwardingFake
+ default:
+ // We only support forwarding for IPv4 and IPv6.
+ }
+ return flag
+}
+
type transportProtocolState struct {
proto TransportProtocol
defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
@@ -363,7 +401,10 @@ type Stack struct {
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
- forwarding bool
+
+ // forwarding contains the enable bits for packet forwarding for different
+ // network protocols.
+ forwarding uint32
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -630,20 +671,28 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables the packet forwarding between NICs.
-func (s *Stack) SetForwarding(enable bool) {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.Lock()
- s.forwarding = enable
- s.mu.Unlock()
+// SetForwarding enables or disables packet forwarding between NICs.
+func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) {
+ flag := getForwardingFlag(protocol)
+ for {
+ forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding))
+ var newValue forwardingFlag
+ if enable {
+ newValue = forwarding | flag
+ } else {
+ newValue = forwarding & ^flag
+ }
+ if atomic.CompareAndSwapUint32(&s.forwarding, uint32(forwarding), uint32(newValue)) {
+ break
+ }
+ }
}
-// Forwarding returns if the packet forwarding between NICs is enabled.
-func (s *Stack) Forwarding() bool {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.forwarding
+// Forwarding returns if packet forwarding between NICs is enabled.
+func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
+ flag := getForwardingFlag(protocol)
+ forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding))
+ return forwarding & flag != 0
}
// SetRouteTable assigns the route table to be used by this stack. It