summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/stack.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/stack.go')
-rw-r--r--pkg/tcpip/stack/stack.go163
1 files changed, 111 insertions, 52 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 843118b13..72760a4a7 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -29,6 +29,7 @@ import (
"time"
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -39,13 +40,6 @@ import (
)
const (
- // ageLimit is set to the same cache stale time used in Linux.
- ageLimit = 1 * time.Minute
- // resolutionTimeout is set to the same ARP timeout used in Linux.
- resolutionTimeout = 1 * time.Second
- // resolutionAttempts is set to the same ARP retries used in Linux.
- resolutionAttempts = 3
-
// DefaultTOS is the default type of service value for network endpoints.
DefaultTOS = 0
)
@@ -65,10 +59,10 @@ type ResumableEndpoint interface {
}
// uniqueIDGenerator is a default unique ID generator.
-type uniqueIDGenerator uint64
+type uniqueIDGenerator atomicbitops.AlignedAtomicUint64
func (u *uniqueIDGenerator) UniqueID() uint64 {
- return atomic.AddUint64((*uint64)(u), 1)
+ return ((*atomicbitops.AlignedAtomicUint64)(u)).Add(1)
}
// Stack is a networking stack, with all supported protocols, NICs, and route
@@ -94,8 +88,9 @@ type Stack struct {
}
}
- mu sync.RWMutex
- nics map[tcpip.NICID]*nic
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*nic
+ defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{}
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -322,7 +317,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {}
func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
- clock = &tcpip.StdClock{}
+ clock = tcpip.NewStdClock()
}
if opts.UniqueID == nil {
@@ -347,22 +342,23 @@ func New(opts Options) *Stack {
}
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: mathrand.New(randSrc),
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: mathrand.New(randSrc),
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -491,37 +487,61 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables packet forwarding between NICs for the
-// passed protocol.
-func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
- protocol, ok := s.networkProtocols[protocolNum]
+// SetNICForwarding enables or disables packet forwarding on the specified NIC
+// for the passed protocol.
+func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrUnknownProtocol{}
+ return &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ return nic.setForwarding(protocol, enable)
+}
+
+// NICForwarding returns the forwarding configuration for the specified NIC.
+func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrNotSupported{}
+ return false, &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol.SetForwarding(enable)
- return nil
+ return nic.forwarding(protocol)
}
-// Forwarding returns true if packet forwarding between NICs is enabled for the
-// passed protocol.
-func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
- protocol, ok := s.networkProtocols[protocolNum]
- if !ok {
- return false
+// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
+// passed protocol and sets the default setting for newly created NICs.
+func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ doneOnce := false
+ for id, nic := range s.nics {
+ if err := nic.setForwarding(protocol, enable); err != nil {
+ // Expect forwarding to be settable on all interfaces if it was set on
+ // one.
+ if doneOnce {
+ panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err))
+ }
+
+ return err
+ }
+
+ doneOnce = true
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
- if !ok {
- return false
+ if enable {
+ s.defaultForwardingEnabled[protocol] = struct{}{}
+ } else {
+ delete(s.defaultForwardingEnabled, protocol)
}
- return forwardingProtocol.Forwarding()
+ return nil
}
// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
@@ -658,6 +678,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
+ for proto := range s.defaultForwardingEnabled {
+ if err := n.setForwarding(proto, true); err != nil {
+ panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err))
+ }
+ }
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -772,7 +797,7 @@ type NICInfo struct {
// MTU is the maximum transmission unit.
MTU uint32
- Stats NICStats
+ Stats tcpip.NICStats
// NetworkStats holds the stats of each NetworkEndpoint bound to the NIC.
NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats
@@ -785,6 +810,10 @@ type NICInfo struct {
// value sent in haType field of an ARP Request sent by this NIC and the
// value expected in the haType field of an ARP response.
ARPHardwareType header.ARPHardwareType
+
+ // Forwarding holds the forwarding status for each network endpoint that
+ // supports forwarding.
+ Forwarding map[tcpip.NetworkProtocolNumber]bool
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -814,17 +843,33 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
netStats[proto] = netEP.Stats()
}
- nics[id] = NICInfo{
+ info := NICInfo{
Name: nic.name,
LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
MTU: nic.LinkEndpoint.MTU(),
- Stats: nic.stats,
+ Stats: nic.stats.local,
NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
+ Forwarding: make(map[tcpip.NetworkProtocolNumber]bool),
}
+
+ for proto := range s.networkProtocols {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ info.Forwarding[proto] = forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+ }
+
+ nics[id] = info
}
return nics
}
@@ -1028,6 +1073,20 @@ func (s *Stack) HandleLocal() bool {
return s.handleLocal
}
+func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ return forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ return false
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given NIC and local address (if provided).
//
@@ -1080,7 +1139,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
+ onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
@@ -1119,7 +1178,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
// requirement to do this from any RFC but simply a choice made to better
// follow a strong host model which the netstack follows at the time of
// writing.
- if canForward && chosenRoute == (tcpip.Route{}) {
+ if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) {
chosenRoute = route
}
}