summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-05-14 16:29:33 -0700
committergVisor bot <gvisor-bot@google.com>2021-05-14 16:32:16 -0700
commitdf2352796d1cbe5eea563d54380be60be18455bc (patch)
treeca9135a78ec0131bf2a517f708c218e5d9d58ade /pkg/tcpip/stack
parent25f0ab3313c356fcfb9e4282eda3b2aa2278956d (diff)
Control forwarding per NetworkEndpoint
...instead of per NetworkProtocol to better conform with linux (https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt): ``` conf/interface/* forwarding - BOOLEAN Enable IP forwarding on this interface. This controls whether packets received _on_ this interface can be forwarded. ``` Fixes #5932. PiperOrigin-RevId: 373888000
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/forwarding_test.go18
-rw-r--r--pkg/tcpip/stack/nic.go29
-rw-r--r--pkg/tcpip/stack/registration.go6
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/stack/stack.go145
-rw-r--r--pkg/tcpip/stack/stack_test.go16
6 files changed, 153 insertions, 63 deletions
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index ff555722e..7107d598d 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -54,6 +54,11 @@ type fwdTestNetworkEndpoint struct {
nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
func (*fwdTestNetworkEndpoint) Enable() tcpip.Error {
@@ -169,11 +174,6 @@ type fwdTestNetworkProtocol struct {
addrResolveDelay time.Duration
onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
}
func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -242,16 +242,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber
return fwdTestNetNumber
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (f *fwdTestNetworkProtocol) Forwarding() bool {
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fwdTestNetworkEndpoint) Forwarding() bool {
f.mu.RLock()
defer f.mu.RUnlock()
return f.mu.forwarding
}
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.forwarding = v
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8d615500f..dbba2c79f 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1000,3 +1000,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t
return d.CheckDuplicateAddress(addr, h), nil
}
+
+func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ forwardingEP.SetForwarding(enable)
+ return nil
+}
+
+func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return false, &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return false, &tcpip.ErrNotSupported{}
+ }
+
+ return forwardingEP.Forwarding(), nil
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index a82c807b4..85bb87b4b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -658,9 +658,9 @@ type IPNetworkEndpointStats interface {
IPStats() *tcpip.IPStats
}
-// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
-type ForwardingNetworkProtocol interface {
- NetworkProtocol
+// ForwardingNetworkEndpoint is a network endpoint that may forward packets.
+type ForwardingNetworkEndpoint interface {
+ NetworkEndpoint
// Forwarding returns the forwarding configuration.
Forwarding() bool
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 8a044c073..f17c04277 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -446,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool {
// If the source NIC and outgoing NIC are different, make sure the stack has
// forwarding enabled, or the packet will be handled locally.
- if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
+ if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
return false
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 483a960c8..8814f45a6 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -95,8 +95,9 @@ type Stack struct {
}
}
- mu sync.RWMutex
- nics map[tcpip.NICID]*nic
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*nic
+ defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{}
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -348,22 +349,23 @@ func New(opts Options) *Stack {
}
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: mathrand.New(randSrc),
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: mathrand.New(randSrc),
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -492,37 +494,61 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
-// passed protocol and sets the default setting for newly created NICs.
-func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
- protocol, ok := s.networkProtocols[protocolNum]
+// SetNICForwarding enables or disables packet forwarding on the specified NIC
+// for the passed protocol.
+func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrUnknownProtocol{}
+ return &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ return nic.setForwarding(protocol, enable)
+}
+
+// NICForwarding returns the forwarding configuration for the specified NIC.
+func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrNotSupported{}
+ return false, &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol.SetForwarding(enable)
- return nil
+ return nic.forwarding(protocol)
}
-// Forwarding returns true if packet forwarding between NICs is enabled for the
-// passed protocol.
-func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
- protocol, ok := s.networkProtocols[protocolNum]
- if !ok {
- return false
+// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
+// passed protocol and sets the default setting for newly created NICs.
+func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ doneOnce := false
+ for id, nic := range s.nics {
+ if err := nic.setForwarding(protocol, enable); err != nil {
+ // Expect forwarding to be settable on all interfaces if it was set on
+ // one.
+ if doneOnce {
+ panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err))
+ }
+
+ return err
+ }
+
+ doneOnce = true
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
- if !ok {
- return false
+ if enable {
+ s.defaultForwardingEnabled[protocol] = struct{}{}
+ } else {
+ delete(s.defaultForwardingEnabled, protocol)
}
- return forwardingProtocol.Forwarding()
+ return nil
}
// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
@@ -659,6 +685,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
+ for proto := range s.defaultForwardingEnabled {
+ if err := n.setForwarding(proto, true); err != nil {
+ panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err))
+ }
+ }
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -786,6 +817,10 @@ type NICInfo struct {
// value sent in haType field of an ARP Request sent by this NIC and the
// value expected in the haType field of an ARP response.
ARPHardwareType header.ARPHardwareType
+
+ // Forwarding holds the forwarding status for each network endpoint that
+ // supports forwarding.
+ Forwarding map[tcpip.NetworkProtocolNumber]bool
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -815,7 +850,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
netStats[proto] = netEP.Stats()
}
- nics[id] = NICInfo{
+ info := NICInfo{
Name: nic.name,
LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
@@ -825,7 +860,23 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
+ Forwarding: make(map[tcpip.NetworkProtocolNumber]bool),
}
+
+ for proto := range s.networkProtocols {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ info.Forwarding[proto] = forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+ }
+
+ nics[id] = info
}
return nics
}
@@ -1029,6 +1080,20 @@ func (s *Stack) HandleLocal() bool {
return s.handleLocal
}
+func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ return forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ return false
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given NIC and local address (if provided).
//
@@ -1081,7 +1146,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
+ onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
@@ -1120,7 +1185,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
// requirement to do this from any RFC but simply a choice made to better
// follow a strong host model which the netstack follows at the time of
// writing.
- if canForward && chosenRoute == (tcpip.Route{}) {
+ if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) {
chosenRoute = route
}
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index ff88b1bd3..02d54d29b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct {
mu struct {
sync.RWMutex
- enabled bool
+ enabled bool
+ forwarding bool
}
nic stack.NetworkInterface
@@ -227,11 +228,6 @@ type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
defaultTTL uint8
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
}
func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -300,15 +296,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (f *fakeNetworkProtocol) Forwarding() bool {
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fakeNetworkEndpoint) Forwarding() bool {
f.mu.RLock()
defer f.mu.RUnlock()
return f.mu.forwarding
}
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (f *fakeNetworkProtocol) SetForwarding(v bool) {
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (f *fakeNetworkEndpoint) SetForwarding(v bool) {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.forwarding = v