summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go10
-rw-r--r--pkg/tcpip/stack/conntrack.go14
-rw-r--r--pkg/tcpip/stack/forwarding_test.go4
-rw-r--r--pkg/tcpip/stack/iptables.go75
-rw-r--r--pkg/tcpip/stack/iptables_targets.go102
-rw-r--r--pkg/tcpip/stack/iptables_types.go41
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go21
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go1256
-rw-r--r--pkg/tcpip/stack/nic.go51
-rw-r--r--pkg/tcpip/stack/nic_test.go3
-rw-r--r--pkg/tcpip/stack/packet_buffer.go60
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go36
-rw-r--r--pkg/tcpip/stack/route.go288
-rw-r--r--pkg/tcpip/stack/stack.go308
-rw-r--r--pkg/tcpip/stack/stack_test.go484
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go63
-rw-r--r--pkg/tcpip/stack/transport_test.go33
18 files changed, 1974 insertions, 877 deletions
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 261705575..9478f3fb7 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -272,6 +272,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
addrState = &addressState{
addressableEndpointState: a,
addr: addr,
+ // Cache the subnet in addrState to avoid calls to addr.Subnet() as that
+ // results in allocations on every call.
+ subnet: addr.Subnet(),
}
a.mu.endpoints[addr.Address] = addrState
addrState.mu.Lock()
@@ -666,7 +669,7 @@ var _ AddressEndpoint = (*addressState)(nil)
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
-
+ subnet tcpip.Subnet
// Lock ordering (from outer to inner lock ordering):
//
// AddressableEndpointState.mu
@@ -686,6 +689,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
}
+// Subnet implements AddressEndpoint.
+func (a *addressState) Subnet() tcpip.Subnet {
+ return a.subnet
+}
+
// GetKind implements AddressEndpoint.
func (a *addressState) GetKind() AddressKind {
a.mu.RLock()
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 0cd1da11f..9a17efcba 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.Addr
- replyTID.srcPort = rt.Port
+ replyTID.srcAddr = address
+ replyTID.srcPort = port
var manip manipType
switch hook {
case Prerouting:
@@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
+ length := uint16(len(tcpHeader) + pkt.Data.Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
- } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size())
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumVV(pkt.Data, xsum)
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 380688038..7a501acdc 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
+func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 8d6d9a7f1..2d8c883cd 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -22,30 +22,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// tableID is an index into IPTables.tables.
-type tableID int
+// TableID identifies a specific table.
+type TableID int
+// Each value identifies a specific table.
const (
- natID tableID = iota
- mangleID
- filterID
- numTables
+ NATID TableID = iota
+ MangleID
+ FilterID
+ NumTables
)
-// Table names.
-const (
- NATTable = "nat"
- MangleTable = "mangle"
- FilterTable = "filter"
-)
-
-// nameToID is immutable.
-var nameToID = map[string]tableID{
- NATTable: natID,
- MangleTable: mangleID,
- FilterTable: filterID,
-}
-
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
@@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second
// all packets.
func DefaultTables() *IPTables {
return &IPTables{
- v4Tables: [numTables]Table{
- natID: Table{
+ v4Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -81,7 +68,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -99,7 +86,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -122,8 +109,8 @@ func DefaultTables() *IPTables {
},
},
},
- v6Tables: [numTables]Table{
- natID: Table{
+ v6Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -146,7 +133,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -164,7 +151,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -187,10 +174,10 @@ func DefaultTables() *IPTables {
},
},
},
- priorities: [NumHooks][]tableID{
- Prerouting: []tableID{mangleID, natID},
- Input: []tableID{natID, filterID},
- Output: []tableID{mangleID, natID, filterID},
+ priorities: [NumHooks][]TableID{
+ Prerouting: []TableID{MangleID, NATID},
+ Input: []TableID{NATID, FilterID},
+ Output: []TableID{MangleID, NATID, FilterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -229,26 +216,20 @@ func EmptyNATTable() Table {
}
}
-// GetTable returns a table by name.
-func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
- id, ok := nameToID[name]
- if !ok {
- return Table{}, false
- }
+// GetTable returns a table with the given id and IP version. It panics when an
+// invalid id is provided.
+func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
it.mu.RLock()
defer it.mu.RUnlock()
if ipv6 {
- return it.v6Tables[id], true
+ return it.v6Tables[id]
}
- return it.v4Tables[id], true
+ return it.v4Tables[id]
}
-// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
- id, ok := nameToID[name]
- if !ok {
- return tcpip.ErrInvalidOptionValue
- }
+// ReplaceTable replaces or inserts table by name. It panics when an invalid id
+// is provided.
+func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error {
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
for _, tableID := range priorities {
// If handlePacket already NATed the packet, we don't need to
// check the NAT table.
- if tableID == natID && pkt.NatDone {
+ if tableID == NATID && pkt.NatDone {
continue
}
var table Table
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 538c4625d..d63e9757c 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -26,13 +28,6 @@ type AcceptTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (at *AcceptTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: at.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
@@ -44,22 +39,11 @@ type DropTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (dt *DropTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: dt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
-// ErrorTargetName is used to mark targets as error targets. Error targets
-// shouldn't be reached - an error has occurred if we fall through to one.
-const ErrorTargetName = "ERROR"
-
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
type ErrorTarget struct {
@@ -67,14 +51,6 @@ type ErrorTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (et *ErrorTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: et.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
@@ -90,14 +66,6 @@ type UserChainTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (uc *UserChainTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: uc.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
@@ -110,50 +78,39 @@ type ReturnTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *ReturnTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
-// RedirectTargetName is used to mark targets as redirect targets. Redirect
-// targets should be reached for only NAT and Mangle tables. These targets will
-// change the destination port/destination IP for packets.
-const RedirectTargetName = "REDIRECT"
-
-// RedirectTarget redirects the packet by modifying the destination port/IP.
+// RedirectTarget redirects the packet to this machine by modifying the
+// destination port/IP. Outgoing packets are redirected to the loopback device,
+// and incoming packets are redirected to the incoming interface (rather than
+// forwarded).
+//
// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
// them.
type RedirectTarget struct {
- // Addr indicates address used to redirect.
- Addr tcpip.Address
-
- // Port indicates port used to redirect.
+ // Port indicates port used to redirect. It is immutable.
Port uint16
- // NetworkProtocol is the network protocol the target is used with.
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *RedirectTarget) ID() TargetID {
- return TargetID{
- Name: RedirectTargetName,
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// implementation only works for Prerouting and calls pkt.Clone(), neither
// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ address = tcpip.Address([]byte{127, 0, 0, 1})
} else {
- rt.Addr = header.IPv6Loopback
+ address = header.IPv6Loopback
}
case Prerouting:
- rt.Addr = address
+ // No-op, as address is already set correctly.
default:
panic("redirect target is supported only on output and prerouting hooks")
}
@@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
+ netHeader := pkt.Network()
+ netHeader.SetDestinationAddress(address)
// Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ if r.RequiresTXTransportChecksum() {
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(protocol, length)
- for _, v := range pkt.Data.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udpHeader.SetChecksum(0)
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumVV(pkt.Data, xsum)
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- pkt.Network().SetDestinationAddress(rt.Addr)
-
// After modification, IPv4 packets need a valid checksum.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Set up conection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 7b3f3e88b..4b86c1be9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -37,7 +37,6 @@ import (
// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
type Hook uint
-// These values correspond to values in include/uapi/linux/netfilter.h.
const (
// Prerouting happens before a packet is routed to applications or to
// be forwarded.
@@ -86,8 +85,8 @@ type IPTables struct {
mu sync.RWMutex
// v4Tables and v6tables map tableIDs to tables. They hold builtin
// tables only, not user tables. mu must be locked for accessing.
- v4Tables [numTables]Table
- v6Tables [numTables]Table
+ v4Tables [NumTables]Table
+ v6Tables [NumTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
@@ -96,7 +95,7 @@ type IPTables struct {
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
- priorities [NumHooks][]tableID
+ priorities [NumHooks][]TableID
connections ConnTrack
@@ -104,6 +103,24 @@ type IPTables struct {
reaperDone chan struct{}
}
+// VisitTargets traverses all the targets of all tables and replaces each with
+// transform(target).
+func (it *IPTables) VisitTargets(transform func(Target) Target) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ for tid := range it.v4Tables {
+ for i, rule := range it.v4Tables[tid].Rules {
+ it.v4Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+ for tid := range it.v6Tables {
+ for i, rule := range it.v6Tables[tid].Rules {
+ it.v6Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+}
+
// A Table defines a set of chains and hooks into the network stack.
//
// It is a list of Rules, entry points (BuiltinChains), and error handlers
@@ -169,7 +186,6 @@ type IPHeaderFilter struct {
// CheckProtocol determines whether the Protocol field should be
// checked during matching.
- // TODO(gvisor.dev/issue/3549): Check this field during matching.
CheckProtocol bool
// Dst matches the destination IP address.
@@ -309,23 +325,8 @@ type Matcher interface {
Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
-// A TargetID uniquely identifies a target.
-type TargetID struct {
- // Name is the target name as stored in the xt_entry_target struct.
- Name string
-
- // NetworkProtocol is the protocol to which the target applies.
- NetworkProtocol tcpip.NetworkProtocolNumber
-
- // Revision is the version of the target.
- Revision uint8
-}
-
// A Target is the interface for taking an action for a packet.
type Target interface {
- // ID uniquely identifies the Target.
- ID() TargetID
-
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index aec77610d..493e48031 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -17,12 +17,19 @@ package stack
import (
"fmt"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+const (
+ // immediateDuration is a duration of zero for scheduling work that needs to
+ // be done immediately but asynchronously to avoid deadlock.
+ immediateDuration time.Duration = 0
+)
+
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
Addr tcpip.Address
@@ -242,7 +249,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
e.job.Schedule(config.RetransmitTimer)
}
- sendUnicastProbe()
+ // Send a probe in another gorountine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(immediateDuration)
case Failed:
e.notifyWakersLocked()
@@ -324,7 +336,12 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
e.job.Schedule(config.RetransmitTimer)
}
- sendMulticastProbe()
+ // Send a probe in another gorountine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(immediateDuration)
case Stale:
e.setStateLocked(Delay)
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index d297c9422..c2b763325 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -47,6 +47,12 @@ const (
entryTestNetDefaultMTU = 65536
)
+// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
+// time.
+func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
+ clock.Advance(immediateDuration)
+}
+
// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
// The UpdatedAtNanos field is ignored due to a lack of a deterministic method
// to predict the time that an event will be dispatched.
@@ -308,7 +314,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
func TestEntryUnknownToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
@@ -317,6 +323,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -354,7 +361,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
func TestEntryUnknownToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
@@ -364,6 +371,7 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Unlock()
// No probes should have been sent.
+ runImmediatelyScheduledJobs(clock)
linkRes.mu.Lock()
diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
linkRes.mu.Unlock()
@@ -488,23 +496,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
func TestEntryIncompleteToReachable(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -519,6 +520,17 @@ func TestEntryIncompleteToReachable(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -552,7 +564,7 @@ func TestEntryIncompleteToReachable(t *testing.T) {
// to Reachable.
func TestEntryAddsAndClearsWakers(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
w := sleep.Waker{}
s := sleep.Sleeper{}
@@ -561,6 +573,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
if got := e.wakers; got != nil {
t.Errorf("got e.wakers = %v, want = nil", got)
}
@@ -584,20 +614,6 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -627,26 +643,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: true,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.isRouter, true; got != want {
- t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -660,6 +666,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
}
linkRes.mu.Unlock()
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -689,23 +709,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
func TestEntryIncompleteToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -720,6 +733,17 @@ func TestEntryIncompleteToStale(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -830,12 +854,30 @@ func (*testLocker) Unlock() {}
func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -861,20 +903,6 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -910,27 +938,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -945,6 +959,24 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -983,16 +1015,9 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1007,6 +1032,17 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
@@ -1053,24 +1089,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1085,6 +1110,21 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1119,38 +1159,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1165,6 +1184,25 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1199,38 +1237,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1245,6 +1262,25 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1279,37 +1315,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1324,6 +1340,24 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1353,31 +1387,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1392,6 +1408,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1430,10 +1468,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1455,20 +1511,6 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1507,31 +1549,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1546,6 +1570,28 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1584,27 +1630,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1619,6 +1651,24 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1657,24 +1707,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
func TestEntryStaleToDelay(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1689,6 +1728,21 @@ func TestEntryStaleToDelay(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1736,21 +1790,9 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleUpperLevelConfirmationLocked()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1765,8 +1807,23 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1833,28 +1890,9 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1869,8 +1907,30 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1937,6 +1997,24 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1959,22 +2037,7 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2031,32 +2094,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2071,6 +2115,29 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2109,25 +2176,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2142,6 +2197,22 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2189,29 +2260,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2226,6 +2281,26 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2277,6 +2352,27 @@ func TestEntryDelayToProbe(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2289,25 +2385,19 @@ func TestEntryDelayToProbe(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
@@ -2367,6 +2457,27 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2376,25 +2487,19 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2459,12 +2564,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
@@ -2473,6 +2572,27 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2482,25 +2602,19 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2569,12 +2683,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
@@ -2583,6 +2691,27 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2592,25 +2721,20 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2696,9 +2820,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
wantProbes := []entryTestProbeInfo{
- // Probe caused by the Delay-to-Probe transition
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
@@ -2729,7 +2851,6 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2795,6 +2916,27 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2804,25 +2946,19 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2843,7 +2979,6 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2918,6 +3053,27 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2927,25 +3083,19 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2963,7 +3113,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -3038,6 +3187,27 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -3047,25 +3217,19 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -3083,7 +3247,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -3158,9 +3321,9 @@ func TestEntryProbeToFailed(t *testing.T) {
e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
{
wantProbes := []entryTestProbeInfo{
- // Caused by the Unknown-to-Incomplete transition.
{
RemoteAddress: entryTestAddr1,
LocalAddress: entryTestAddr2,
@@ -3283,6 +3446,26 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -3293,33 +3476,28 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
clock.Advance(waitFor)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The next three probe are caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The next three probe are sent in Probe.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 17f2e6b46..60c81a3aa 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -348,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address
return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
+func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint)
+ if ep != nil {
+ ep.DecRef()
+ return true
+ }
+
+ return false
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
@@ -555,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
}
func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
defer r.Release()
- r.RemoteLinkAddress = remotelinkAddr
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ n.getNetworkEndpoint(protocol).HandlePacket(pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -594,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
if local == "" {
local = n.LinkEndpoint.LinkAddress()
}
+ pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
@@ -669,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
// Found a NIC.
- n := r.nic
+ n := r.localAddressNIC
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
- r.RemoteLinkAddress = remote
+ pkt.NICID = n.ID()
r.RemoteAddress = src
- // TODO(b/123449044): Update the source NIC as well.
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ pkt.NetworkPacketInfo = r.networkPacketInfo()
+ n.getNetworkEndpoint(protocol).HandlePacket(pkt)
addressEndpoint.DecRef()
r.Release()
return
@@ -735,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
+func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -747,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- n.stack.demux.deliverRawPacket(r, protocol, pkt)
+ n.stack.demux.deliverRawPacket(protocol, pkt)
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
@@ -776,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
return TransportPacketHandled
}
- id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
+ netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber]
+ if !ok {
+ panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers()))
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
+ id := TransportEndpointID{
+ LocalPort: dstPort,
+ LocalAddress: dst,
+ RemotePort: srcPort,
+ RemoteAddress: src,
+ }
+ if n.stack.demux.deliverPacket(protocol, pkt, id) {
return TransportPacketHandled
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, pkt) {
+ if state.defaultHandler(id, pkt) {
return TransportPacketHandled
}
}
@@ -791,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// We could not find an appropriate destination for this packet so
// give the protocol specific error handler a chance to handle it.
// If it doesn't handle it then we should do so.
- switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res {
+ switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res {
case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
return TransportPacketHandled
@@ -895,7 +916,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed) unless the NIC is in spoofing mode, or temporary.
func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
n.mu.RLock()
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 4af04846f..5b5c58afb 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
}
// HandlePacket implements NetworkEndpoint.HandlePacket.
-func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
-}
+func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {}
// Close implements NetworkEndpoint.Close.
func (e *testIPv6Endpoint) Close() {
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 7f54a6de8..664cc6fa0 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -112,6 +112,16 @@ type PacketBuffer struct {
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
+
+ // NICID is the ID of the interface the network packet was received at.
+ NICID tcpip.NICID
+
+ // RXTransportChecksumValidated indicates that transport checksum verification
+ // may be safely skipped.
+ RXTransportChecksumValidated bool
+
+ // NetworkPacketInfo holds an incoming packet's network-layer information.
+ NetworkPacketInfo NetworkPacketInfo
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
// Clone should be called in such cases so that no modifications is done to
// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
- newPk := &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- headers: pk.headers,
- header: pk.header,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
- TransportProtocolNumber: pk.TransportProtocolNumber,
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
+ PktType: pk.PktType,
+ NICID: pk.NICID,
+ RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
+ NetworkPacketInfo: pk.NetworkPacketInfo,
}
- return newPk
+}
+
+// SourceLinkAddress returns the source link address of the packet.
+func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress {
+ link := pk.LinkHeader().View()
+
+ if link.IsEmpty() {
+ return ""
+ }
+
+ return header.Ethernet(link).SourceAddress()
}
// Network returns the network header as a header.Network.
@@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
+// CloneToInbound makes a shallow copy of the packet buffer to be used as an
+// inbound packet.
+//
+// See PacketBuffer.Data for details about how a packet buffer holds an inbound
+// packet.
+func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
+ return NewPacketBuffer(PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
+ })
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index f838eda8d..5d364a2b0 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
} else if _, err := p.route.Resolve(nil); err != nil {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
+ p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 203f3b51f..b8f333057 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -63,17 +63,28 @@ const (
ControlUnknown
)
+// NetworkPacketInfo holds information about a network layer packet.
+type NetworkPacketInfo struct {
+ // RemoteAddressBroadcast is true if the packet's remote address is a
+ // broadcast address.
+ RemoteAddressBroadcast bool
+
+ // LocalAddressBroadcast is true if the packet's local address is a broadcast
+ // address.
+ LocalAddressBroadcast bool
+}
+
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
// UniqueID returns an unique ID for this transport endpoint.
UniqueID() uint64
- // HandlePacket is called by the stack when new packets arrive to
- // this transport endpoint. It sets pkt.TransportHeader.
+ // HandlePacket is called by the stack when new packets arrive to this
+ // transport endpoint. It sets the packet buffer's transport header.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(TransportEndpointID, *PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
@@ -105,8 +116,8 @@ type RawTransportEndpoint interface {
// this transport endpoint. The packet contains all data from the link
// layer up.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(*PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -172,9 +183,9 @@ type TransportProtocol interface {
// protocol that don't match any existing endpoint. For example,
// it is targeted at a port that has no listeners.
//
- // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // HandleUnknownDestinationPacket takes ownership of the packet if it handles
// the issue.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
+ HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -227,8 +238,8 @@ type TransportDispatcher interface {
//
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
- // DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition
+ // DeliverTransportPacket takes ownership of the packet.
+ DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -329,6 +340,9 @@ type AssignableAddressEndpoint interface {
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix
+ // Subnet returns the subnet of the endpoint's address.
+ Subnet() tcpip.Subnet
+
// IsAssigned returns whether or not the endpoint is considered bound
// to its NetworkEndpoint.
IsAssigned(allowExpired bool) bool
@@ -547,7 +561,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ HandlePacket(pkt *PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index b76e2d37b..15ff437c7 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -45,11 +47,16 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
- // nic is the NIC the route goes through.
- nic *NIC
+ // localAddressNIC is the interface the address is associated with.
+ // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
+ // address's assigned status without the NIC.
+ localAddressNIC *NIC
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
- // addressEndpoint is the local address this route is associated with.
- addressEndpoint AssignableAddressEndpoint
+ // outgoingNIC is the interface this route uses to write packets.
+ outgoingNIC *NIC
// linkCache is set if link address resolution is enabled for this protocol on
// the route's NIC.
@@ -60,51 +67,144 @@ type Route struct {
linkRes LinkAddressResolver
}
+// constructAndValidateRoute validates and initializes a route. It takes
+// ownership of the provided local address.
+//
+// Returns an empty route if validation fails.
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
+ addrWithPrefix := addressEndpoint.AddressWithPrefix()
+
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ addressEndpoint.DecRef()
+ return Route{}
+ }
+
+ // If no remote address is provided, use the local address.
+ if len(remoteAddr) == 0 {
+ remoteAddr = addrWithPrefix.Address
+ }
+
+ r := makeRoute(
+ netProto,
+ addrWithPrefix.Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ addressEndpoint,
+ handleLocal,
+ multicastLoop,
+ )
+
+ // If the route requires us to send a packet through some gateway, do not
+ // broadcast it.
+ if len(gateway) > 0 {
+ r.NextHop = gateway
+ } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ }
+
+ return r
+}
+
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+ if localAddressNIC.stack != outgoingNIC.stack {
+ panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
+ }
+
loop := PacketOut
- if handleLocal && localAddr != "" && remoteAddr == localAddr {
- loop = PacketLoop
- } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
- loop |= PacketLoop
- } else if remoteAddr == header.IPv4Broadcast {
- loop |= PacketLoop
+
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if !outgoingNIC.IsLoopback() {
+ if handleLocal && localAddr != "" && remoteAddr == localAddr {
+ loop = PacketLoop
+ } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
+ loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
+ } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
+ loop |= PacketLoop
+ }
}
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- addressEndpoint: addressEndpoint,
- nic: nic,
- Loop: loop,
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ localAddressNIC: localAddressNIC,
+ localAddressEndpoint: localAddressEndpoint,
+ outgoingNIC: outgoingNIC,
+ Loop: loop,
}
- if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = r.nic.stack
+ r.linkCache = r.outgoingNIC.stack
}
}
return r
}
+// makeLocalRoute initializes a new local route. It takes ownership of the
+// provided AssignableAddressEndpoint.
+//
+// A local route is a route to a destination that is local to the stack.
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+ loop := PacketLoop
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if outgoingNIC.IsLoopback() {
+ loop = PacketOut
+ }
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+// PopulatePacketInfo populates a packet buffer's packet information fields.
+//
+// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
+// the network layer.
+func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) {
+ if r.local() {
+ pkt.RXTransportChecksumValidated = true
+ }
+ pkt.NetworkPacketInfo = r.networkPacketInfo()
+}
+
+// networkPacketInfo returns the network packet information of the route.
+//
+// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
+// the network layer.
+func (r *Route) networkPacketInfo() NetworkPacketInfo {
+ return NetworkPacketInfo{
+ RemoteAddressBroadcast: r.IsOutboundBroadcast(),
+ LocalAddressBroadcast: r.isInboundBroadcast(),
+ }
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.nic.ID()
+ return r.outgoingNIC.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
- return r.nic.stack.Stats()
+ return r.outgoingNIC.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
}
-// Capabilities returns the link-layer capabilities of the route.
-func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint.Capabilities()
+// RequiresTXTransportChecksum returns false if the route does not require
+// transport checksums to be populated.
+func (r *Route) RequiresTXTransportChecksum() bool {
+ if r.local() {
+ return false
+ }
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0
+}
+
+// HasSoftwareGSOCapability returns true if the route supports software GSO.
+func (r *Route) HasSoftwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+}
+
+// HasHardwareGSOCapability returns true if the route supports hardware GSO.
+func (r *Route) HasHardwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+}
+
+// HasSaveRestoreCapability returns true if the route supports save/restore.
+func (r *Route) HasSaveRestoreCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0
+}
+
+// HasDisconncetOkCapability returns true if the route supports disconnecting.
+func (r *Route) HasDisconncetOkCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker)
+ // If specified, the local address used for link address resolution must be an
+ // address on the outgoing interface.
+ var linkAddressResolutionRequestLocalAddr tcpip.Address
+ if r.localAddressNIC == r.outgoingNIC {
+ linkAddressResolutionRequestLocalAddr = r.LocalAddress
+ }
+
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
+ entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
if err != nil {
return ch, err
}
@@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
if err != nil {
return ch, err
}
@@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
neigh.removeWaker(nextAddr, waker)
return
}
- r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker)
+ r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
+}
+
+// local returns true if the route is a local route.
+func (r *Route) local() bool {
+ return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
-// the link address before the this route can be written to.
+// the link address before the route can be written to.
//
-// The NIC r uses must not be locked.
+// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if r.nic.neigh != nil {
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == ""
+ if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ return false
}
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == ""
+
+ return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil
+}
+
+func (r *Route) isValidForOutgoing() bool {
+ if !r.outgoingNIC.Enabled() {
+ return false
+ }
+
+ if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ return false
+ }
+
+ // If the source NIC and outgoing NIC are different, make sure the stack has
+ // forwarding enabled, or the packet will be handled locally.
+ if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) {
+ return false
+ }
+
+ return true
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.nic.getNetworkEndpoint(r.NetProto).MTU()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.addressEndpoint != nil {
- r.addressEndpoint.DecRef()
- r.addressEndpoint = nil
+ if r.localAddressEndpoint != nil {
+ r.localAddressEndpoint.DecRef()
+ r.localAddressEndpoint = nil
}
}
// Clone clones the route.
func (r *Route) Clone() Route {
- if r.addressEndpoint != nil {
- _ = r.addressEndpoint.IncRef()
+ if r.localAddressEndpoint != nil {
+ if !r.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
+ }
}
return *r
}
@@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route {
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
- return r.nic.stack
+ return r.outgoingNIC.stack
}
func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
@@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.addressEndpoint.AddressWithPrefix().Subnet()
+ subnet := r.localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool {
return r.isV4Broadcast(r.RemoteAddress)
}
-// IsInboundBroadcast returns true if the route is for an inbound broadcast
+// isInboundBroadcast returns true if the route is for an inbound broadcast
// packet.
-func (r *Route) IsInboundBroadcast() bool {
+func (r *Route) isInboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.LocalAddress)
}
@@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool {
// ReverseRoute returns new route with given source and destination address.
func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- addressEndpoint: r.addressEndpoint,
- nic: r.nic,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
+ NetProto: r.NetProto,
+ LocalAddress: dst,
+ LocalLinkAddress: r.RemoteLinkAddress,
+ RemoteAddress: src,
+ RemoteLinkAddress: r.LocalLinkAddress,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ localAddressEndpoint: r.localAddressEndpoint,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index e8f1c110e..a23fb97ff 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"bytes"
"encoding/binary"
+ "fmt"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -52,7 +53,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -518,6 +519,10 @@ type Options struct {
//
// RandSource must be thread-safe.
RandSource mathrand.Source
+
+ // IPTables are the initial iptables rules. If nil, iptables will allow
+ // all traffic.
+ IPTables *IPTables
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -620,6 +625,10 @@ func New(opts Options) *Stack {
randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
}
+ if opts.IPTables == nil {
+ opts.IPTables = DefaultTables()
+ }
+
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
@@ -633,7 +642,7 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- tables: DefaultTables(),
+ tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
nudConfigs: opts.NUDConfigs,
@@ -751,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -1194,54 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
+// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route
+// from the specified NIC.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
+ if localAddressEndpoint == nil {
+ return Route{}, false
+ }
+
+ var outgoingNIC *NIC
+ // Prefer a local route to the same interface as the local address.
+ if localAddressNIC.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = localAddressNIC
+ }
+
+ // If the remote address isn't owned by the local address's NIC, check all
+ // NICs.
+ if outgoingNIC == nil {
+ for _, nic := range s.nics {
+ if nic.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = nic
+ break
+ }
+ }
+ }
+
+ // If the remote address is not owned by the stack, we can't return a local
+ // route.
+ if outgoingNIC == nil {
+ localAddressEndpoint.DecRef()
+ return Route{}, false
+ }
+
+ r := makeLocalRoute(
+ netProto,
+ localAddressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ localAddressEndpoint,
+ )
+
+ if r.IsOutboundBroadcast() {
+ r.Release()
+ return Route{}, false
+ }
+
+ return r, true
+}
+
+// findLocalRouteRLocked returns a local route.
+//
+// A local route is a route to some remote address which the stack owns. That
+// is, a local route is a route where packets never have to leave the stack.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ if len(localAddr) == 0 {
+ localAddr = remoteAddr
+ }
+
+ if localAddressNICID == 0 {
+ for _, localAddressNIC := range s.nics {
+ if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
+ return r, true
+ }
+ }
+
+ return Route{}, false
+ }
+
+ if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
+ return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
+ }
+
+ return Route{}, false
+}
+
// FindRoute creates a route to the given destination address, leaving through
-// the given nic and local address (if provided).
+// the given NIC and local address (if provided).
+//
+// If a NIC is not specified, the returned route will leave through the same
+// NIC as the NIC that has the local address assigned when forwarding is
+// disabled. If forwarding is enabled and the NIC is unspecified, the route may
+// leave through any interface unless the route is link-local.
+//
+// If no local address is provided, the stack will select a local address. If no
+// remote address is provided, the stack wil use a remote address equal to the
+// local address.
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
+ isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
+ needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
+
+ if s.handleLocal && !isMulticast && !isLocalBroadcast {
+ if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ return r, nil
+ }
+ }
+
+ // If the interface is specified and we do not need a route, return a route
+ // through the interface if the interface is valid and enabled.
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.Enabled() {
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
+ return makeRoute(
+ netProto,
+ addressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ nic, /* outboundNIC */
+ nic, /* localAddressNIC*/
+ addressEndpoint,
+ s.handleLocal,
+ multicastLoop,
+ ), nil
}
}
- } else {
- for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
- continue
+
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
+ }
+
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+
+ // Find a route to the remote with the route table.
+ var chosenRoute tcpip.Route
+ for _, route := range s.routeTable {
+ if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) {
+ continue
+ }
+
+ nic, ok := s.nics[route.NIC]
+ if !ok || !nic.Enabled() {
+ continue
+ }
+
+ if id == 0 || id == route.NIC {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = route.Gateway
+ }
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
+ if r == (Route{}) {
+ panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
+ }
+ return r, nil
}
- if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
- if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if len(remoteAddr) == 0 {
- // If no remote address was provided, then the route
- // provided will refer to the link local address.
- remoteAddr = addressEndpoint.AddressWithPrefix().Address
- }
+ }
+
+ // If the stack has forwarding enabled and we haven't found a valid route to
+ // the remote address yet, keep track of the first valid route. We keep
+ // iterating because we prefer routes that let us use a local address that
+ // is assigned to the outgoing interface. There is no requirement to do this
+ // from any RFC but simply a choice made to better follow a strong host
+ // model which the netstack follows at the time of writing.
+ if canForward && chosenRoute == (tcpip.Route{}) {
+ chosenRoute = route
+ }
+ }
+
+ if chosenRoute != (tcpip.Route{}) {
+ // At this point we know the stack has forwarding enabled since chosenRoute is
+ // only set when forwarding is enabled.
+ nic, ok := s.nics[chosenRoute.NIC]
+ if !ok {
+ // If the route's NIC was invalid, we should not have chosen the route.
+ panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC))
+ }
+
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = chosenRoute.Gateway
+ }
- r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
- if len(route.Gateway) > 0 {
- if needRoute {
- r.NextHop = route.Gateway
- }
- } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ // Use the specified NIC to get the local address endpoint.
+ if id != 0 {
+ if aNIC, ok := s.nics[id]; ok {
+ if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ return r, nil
}
+ }
+ }
+
+ return Route{}, tcpip.ErrNoRoute
+ }
+
+ if id == 0 {
+ // If an interface is not specified, try to find a NIC that holds the local
+ // address endpoint to construct a route.
+ for _, aNIC := range s.nics {
+ addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto)
+ if addressEndpoint == nil {
+ continue
+ }
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
return r, nil
}
}
}
}
- if !needRoute {
- return Route{}, tcpip.ErrNetworkUnreachable
+ if needRoute {
+ return Route{}, tcpip.ErrNoRoute
}
-
- return Route{}, tcpip.ErrNoRoute
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1457,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
// FindTransportEndpoint finds an endpoint that most closely matches the provided
// id. If no endpoint is found it returns nil.
-func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
- return s.demux.findTransportEndpoint(netProto, transProto, id, r)
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, nicID)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
@@ -1910,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
return tcpip.NewJob(s.clock, l, f)
}
+
+// ParseResult indicates the result of a parsing attempt.
+type ParseResult int
+
+const (
+ // ParsedOK indicates that a packet was successfully parsed.
+ ParsedOK ParseResult = iota
+
+ // UnknownNetworkProtocol indicates that the network protocol is unknown.
+ UnknownNetworkProtocol
+
+ // NetworkLayerParseError indicates that the network packet was not
+ // successfully parsed.
+ NetworkLayerParseError
+
+ // UnknownTransportProtocol indicates that the transport protocol is unknown.
+ UnknownTransportProtocol
+
+ // TransportLayerParseError indicates that the transport packet was not
+ // successfully parsed.
+ TransportLayerParseError
+)
+
+// ParsePacketBuffer parses the provided packet buffer.
+func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult {
+ netProto, ok := s.networkProtocols[protocol]
+ if !ok {
+ return UnknownNetworkProtocol
+ }
+
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
+ if !ok {
+ return NetworkLayerParseError
+ }
+ if !hasTransportHdr {
+ return ParsedOK
+ }
+
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber {
+ return ParsedOK
+ }
+
+ pkt.TransportProtocolNumber = transProtoNum
+ // Parse the transport header if present.
+ state, ok := s.transportProtocols[transProtoNum]
+ if !ok {
+ return UnknownTransportProtocol
+ }
+
+ if !state.proto.Parse(pkt) {
+ return TransportLayerParseError
+ }
+
+ return ParsedOK
+}
+
+// networkProtocolNumbers returns the network protocol numbers the stack is
+// configured with.
+func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
+ protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols))
+ for p := range s.networkProtocols {
+ protos = append(protos, p)
+ }
+ return protos
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 4eed4ced4..dedfdd435 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"math"
+ "net"
"sort"
"testing"
"time"
@@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++
+ netHdr := pkt.NetworkHeader().View()
+ f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++
// Handle control packets.
- if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
@@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
+ pkt.NetworkProtocolNumber = fakeNetNumber
hdr[dstAddrOffset] = r.RemoteAddress[0]
hdr[srcAddrOffset] = r.LocalAddress[0]
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- f.HandlePacket(r, pkt)
+ pkt := pkt.Clone()
+ r.PopulatePacketInfo(pkt)
+ f.HandlePacket(pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
if !ok {
return 0, false, false
}
+ pkt.NetworkProtocolNumber = fakeNetNumber
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
@@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) {
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
+// TestExternalSendWithHandleLocal tests that the stack creates a non-local
+// route when spoofing or promiscuous mode are enabled.
+//
+// This test makes sure that packets are transmitted from the stack.
+func TestExternalSendWithHandleLocal(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+
+ localAddr = tcpip.Address("\x01")
+ dstAddr = tcpip.Address("\x03")
+ )
+
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ name string
+ configureStack func(*testing.T, *stack.Stack)
+ }{
+ {
+ name: "Default",
+ configureStack: func(*testing.T, *stack.Stack) {},
+ },
+ {
+ name: "Spoofing",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetSpoofing(nicID, true); err != nil {
+ t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ {
+ name: "Promiscuous",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
+ t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, handleLocal := range []bool{true, false} {
+ t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ HandleLocal: handleLocal,
+ })
+
+ ep := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
+
+ test.configureStack(t, s)
+
+ r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err)
+ }
+ defer r.Release()
+
+ if r.LocalAddress != localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+
+ if n := ep.Drain(); n != 0 {
+ t.Fatalf("got ep.Drain() = %d, want = 0", n)
+ }
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: fakeTransNumber,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.NewView(10).ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, _, _): %s", err)
+ }
+ if n := ep.Drain(); n != 1 {
+ t.Fatalf("got ep.Drain() = %d, want = 1", n)
+ }
+ })
+ }
+ })
+ }
+}
+
func TestSpoofingWithAddress(t *testing.T) {
localAddr := tcpip.Address("\x01")
nonExistentLocalAddr := tcpip.Address("\x02")
@@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
RemoteAddress: ipv4SubnetBcast,
RemoteLinkAddress: header.EthernetBroadcastAddress,
NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
+ Loop: stack.PacketOut | stack.PacketLoop,
},
},
// Broadcast to a locally attached /31 subnet does not populate the
@@ -3756,3 +3862,369 @@ func TestRemoveRoutes(t *testing.T) {
}
}
}
+
+func TestFindRouteWithForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Addr = tcpip.Address("\x01")
+ nic2Addr = tcpip.Address("\x02")
+ remoteAddr = tcpip.Address("\x03")
+ )
+
+ type netCfg struct {
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1Addr tcpip.Address
+ nic2Addr tcpip.Address
+ remoteAddr tcpip.Address
+ }
+
+ fakeNetCfg := netCfg{
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1Addr: nic1Addr,
+ nic2Addr: nic2Addr,
+ remoteAddr: remoteAddr,
+ }
+
+ globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
+ globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
+
+ ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: llAddr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: globalIPv6Addr1,
+ }
+ ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: llAddr1,
+ remoteAddr: llAddr2,
+ }
+ ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ }
+
+ tests := []struct {
+ name string
+
+ netCfg netCfg
+ forwardingEnabled bool
+
+ addrNIC tcpip.NICID
+ localAddr tcpip.Address
+
+ findRouteErr *tcpip.Error
+ dependentOnForwarding bool
+ }{
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ }
+
+ if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ }
+
+ if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
+
+ r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if err != test.findRouteErr {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
+ }
+ defer r.Release()
+
+ if test.findRouteErr != nil {
+ return
+ }
+
+ if r.LocalAddress != test.localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr)
+ }
+ if r.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Sending a packet should always go through NIC2 since we only install a
+ // route to test.netCfg.remoteAddr through NIC2.
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := send(r, data); err != nil {
+ t.Fatalf("send(_, _): %s", err)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not sent through ep2")
+ }
+ if pkt.Route.LocalAddress != test.localAddr {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ }
+ if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if !test.forwardingEnabled || !test.dependentOnForwarding {
+ return
+ }
+
+ // Disabling forwarding when the route is dependent on forwarding being
+ // enabled should make the route invalid.
+ if err := s.SetForwarding(test.netCfg.proto, false); err != nil {
+ t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err)
+ }
+ if err := send(r, data); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ if n := ep2.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep2", n)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 35e5b1a2e..f183ec6e4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isInboundMulticastOrBroadcast(r) {
- mpep.handlePacketAll(r, id, pkt)
+ if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
+ mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
- queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
return
}
- transEP.HandlePacket(r, id, pkt)
+ transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
type transportDemuxer struct {
+ stack *Stack
+
// protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
queuedProtocols map[protocolIDs]queuedTransportProtocol
@@ -262,11 +264,12 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
+ QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{
+ stack: stack,
protocol: make(map[protocolIDs]*transportEndpoints),
queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
}
@@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
if mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ queuedProtocol.QueuePacket(endpoint, id, pkt.Clone())
} else {
- endpoint.HandlePacket(r, id, pkt.Clone())
+ endpoint.HandlePacket(id, pkt.Clone())
}
}
if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ queuedProtocol.QueuePacket(endpoint, id, pkt)
} else {
- endpoint.HandlePacket(r, id, pkt)
+ endpoint.HandlePacket(id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
if len(destEPs) == 0 {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
return false
}
// handlePacket takes ownership of pkt, so each endpoint needs its own
// copy except for the final one.
for _, ep := range destEPs[:len(destEPs)-1] {
- ep.handlePacket(r, id, pkt.Clone())
+ ep.handlePacket(id, pkt.Clone())
}
- destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ destEPs[len(destEPs)-1].handlePacket(id, pkt)
return true
}
@@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// destination address, then do nothing further and instruct the caller to do
// the same. The network layer handles address validation for specified source
// addresses.
- if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
+ if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
// TCP can only be used to communicate between a single source and a
- // single destination; the addresses must be unicast.
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ // single destination; the addresses must be unicast.e
+ d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
return true
}
@@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RUnlock()
if ep == nil {
if protocol == header.UDPProtocolNumber {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
}
return false
}
- ep.handlePacket(r, id, pkt)
+ ep.handlePacket(id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
@@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, pkt)
+ rawEP.HandlePacket(pkt.Clone())
foundRaw = true
}
eps.mu.RUnlock()
@@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
}
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
-func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
@@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[nicID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isInboundMulticastOrBroadcast(r *Route) bool {
- return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
+func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
+ return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}
func isSpecified(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 6b8071467..c457b67a2 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
- if f.acceptQueue != nil {
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- ID: f.ID,
- NetProto: f.NetProto,
- },
- proto: f.proto,
- peerAddr: r.RemoteAddress,
- route: r.Clone(),
- })
+ if f.acceptQueue == nil {
+ return
}
+
+ netHdr := pkt.NetworkHeader().View()
+ route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return
+ }
+ route.ResolveWith(pkt.SourceLinkAddress())
+
+ f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ ID: f.ID,
+ NetProto: f.NetProto,
+ },
+ proto: f.proto,
+ peerAddr: route.RemoteAddress,
+ route: route,
+ })
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
@@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
return stack.UnknownDestinationPacketHandled
}