summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go76
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go82
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go4
-rw-r--r--pkg/tcpip/stack/nic.go29
-rw-r--r--pkg/tcpip/stack/registration.go6
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/stack/stack.go145
9 files changed, 216 insertions, 136 deletions
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
index eed49f5d2..5123b7d6a 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -83,6 +83,8 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts
panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize))
}
+ configs.Validate()
+
*d = DAD{
opts: opts,
configs: configs,
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 049811cbb..23178277a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -63,9 +63,15 @@ const (
fragmentblockSize = 8
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -82,6 +88,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -151,14 +163,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
e.mu.Lock()
defer e.mu.Unlock()
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
if forwarding {
// There does not seem to be an RFC requirement for a node to join the all
// routers multicast address but
@@ -852,7 +882,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
addressEndpoint.DecRef()
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -1144,7 +1174,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1165,12 +1194,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
ids []uint32
hashIV uint32
@@ -1283,35 +1306,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 4051fda07..307e1972d 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -745,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- stack := e.protocol.stack
-
- // Is the networking stack operating as a router?
- if !stack.Forwarding(ProtocolNumber) {
- // ... No, silently drop the packet.
+ if !e.Forwarding() {
received.routerOnlyPacketsDroppedByHost.Increment()
return
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f0e06f86b..95e11ac51 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -63,6 +63,11 @@ const (
buckets = 2048
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
// policyTable is the default policy table defined in RFC 6724 section 2.1.
//
// A more human-readable version:
@@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 {
var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -187,6 +193,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -405,20 +417,38 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
}
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
allRoutersGroups := [...]tcpip.Address{
header.IPv6AllRoutersInterfaceLocalMulticastAddress,
header.IPv6AllRoutersLinkLocalMulticastAddress,
header.IPv6AllRoutersSiteLocalMulticastAddress,
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
if forwarding {
// As per RFC 4291 section 2.8:
//
@@ -1109,7 +1139,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -1932,7 +1962,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1957,12 +1986,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
fragmentation *fragmentation.Fragmentation
}
@@ -2137,35 +2160,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index b29fed347..f0ff111c5 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -705,7 +705,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a protocol-wide configuration, so we check the
// protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
// packets.
- if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment()
return
}
@@ -1710,7 +1710,7 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
- if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
return
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8d615500f..dbba2c79f 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1000,3 +1000,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t
return d.CheckDuplicateAddress(addr, h), nil
}
+
+func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ forwardingEP.SetForwarding(enable)
+ return nil
+}
+
+func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ ep := n.getNetworkEndpoint(protocol)
+ if ep == nil {
+ return false, &tcpip.ErrUnknownProtocol{}
+ }
+
+ forwardingEP, ok := ep.(ForwardingNetworkEndpoint)
+ if !ok {
+ return false, &tcpip.ErrNotSupported{}
+ }
+
+ return forwardingEP.Forwarding(), nil
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index a82c807b4..85bb87b4b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -658,9 +658,9 @@ type IPNetworkEndpointStats interface {
IPStats() *tcpip.IPStats
}
-// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
-type ForwardingNetworkProtocol interface {
- NetworkProtocol
+// ForwardingNetworkEndpoint is a network endpoint that may forward packets.
+type ForwardingNetworkEndpoint interface {
+ NetworkEndpoint
// Forwarding returns the forwarding configuration.
Forwarding() bool
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 8a044c073..f17c04277 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -446,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool {
// If the source NIC and outgoing NIC are different, make sure the stack has
// forwarding enabled, or the packet will be handled locally.
- if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
+ if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) {
return false
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 483a960c8..8814f45a6 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -95,8 +95,9 @@ type Stack struct {
}
}
- mu sync.RWMutex
- nics map[tcpip.NICID]*nic
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*nic
+ defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{}
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -348,22 +349,23 @@ func New(opts Options) *Stack {
}
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: mathrand.New(randSrc),
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: mathrand.New(randSrc),
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -492,37 +494,61 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
-// passed protocol and sets the default setting for newly created NICs.
-func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
- protocol, ok := s.networkProtocols[protocolNum]
+// SetNICForwarding enables or disables packet forwarding on the specified NIC
+// for the passed protocol.
+func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrUnknownProtocol{}
+ return &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ return nic.setForwarding(protocol, enable)
+}
+
+// NICForwarding returns the forwarding configuration for the specified NIC.
+func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
if !ok {
- return &tcpip.ErrNotSupported{}
+ return false, &tcpip.ErrUnknownNICID{}
}
- forwardingProtocol.SetForwarding(enable)
- return nil
+ return nic.forwarding(protocol)
}
-// Forwarding returns true if packet forwarding between NICs is enabled for the
-// passed protocol.
-func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
- protocol, ok := s.networkProtocols[protocolNum]
- if !ok {
- return false
+// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the
+// passed protocol and sets the default setting for newly created NICs.
+func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ doneOnce := false
+ for id, nic := range s.nics {
+ if err := nic.setForwarding(protocol, enable); err != nil {
+ // Expect forwarding to be settable on all interfaces if it was set on
+ // one.
+ if doneOnce {
+ panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err))
+ }
+
+ return err
+ }
+
+ doneOnce = true
}
- forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
- if !ok {
- return false
+ if enable {
+ s.defaultForwardingEnabled[protocol] = struct{}{}
+ } else {
+ delete(s.defaultForwardingEnabled, protocol)
}
- return forwardingProtocol.Forwarding()
+ return nil
}
// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
@@ -659,6 +685,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
+ for proto := range s.defaultForwardingEnabled {
+ if err := n.setForwarding(proto, true); err != nil {
+ panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err))
+ }
+ }
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -786,6 +817,10 @@ type NICInfo struct {
// value sent in haType field of an ARP Request sent by this NIC and the
// value expected in the haType field of an ARP response.
ARPHardwareType header.ARPHardwareType
+
+ // Forwarding holds the forwarding status for each network endpoint that
+ // supports forwarding.
+ Forwarding map[tcpip.NetworkProtocolNumber]bool
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -815,7 +850,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
netStats[proto] = netEP.Stats()
}
- nics[id] = NICInfo{
+ info := NICInfo{
Name: nic.name,
LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
@@ -825,7 +860,23 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
+ Forwarding: make(map[tcpip.NetworkProtocolNumber]bool),
}
+
+ for proto := range s.networkProtocols {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ info.Forwarding[proto] = forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+ }
+
+ nics[id] = info
}
return nics
}
@@ -1029,6 +1080,20 @@ func (s *Stack) HandleLocal() bool {
return s.handleLocal
}
+func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool {
+ switch forwarding, err := nic.forwarding(proto); err.(type) {
+ case nil:
+ return forwarding
+ case *tcpip.ErrUnknownProtocol:
+ panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID()))
+ case *tcpip.ErrNotSupported:
+ // Not all network protocols support forwarding.
+ return false
+ default:
+ panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err))
+ }
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given NIC and local address (if provided).
//
@@ -1081,7 +1146,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
+ onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
@@ -1120,7 +1185,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
// requirement to do this from any RFC but simply a choice made to better
// follow a strong host model which the netstack follows at the time of
// writing.
- if canForward && chosenRoute == (tcpip.Route{}) {
+ if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) {
chosenRoute = route
}
}