summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD45
-rw-r--r--pkg/tcpip/stack/conntrack.go756
-rw-r--r--pkg/tcpip/stack/fake_time_test.go209
-rw-r--r--pkg/tcpip/stack/forwarder.go4
-rw-r--r--pkg/tcpip/stack/forwarder_test.go143
-rw-r--r--pkg/tcpip/stack/iptables.go308
-rw-r--r--pkg/tcpip/stack/iptables_state.go40
-rw-r--r--pkg/tcpip/stack/iptables_targets.go28
-rw-r--r--pkg/tcpip/stack/iptables_types.go129
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go2
-rw-r--r--pkg/tcpip/stack/ndp.go252
-rw-r--r--pkg/tcpip/stack/ndp_test.go274
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go335
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go1752
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go482
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2770
-rw-r--r--pkg/tcpip/stack/neighborstate_string.go44
-rw-r--r--pkg/tcpip/stack/nic.go231
-rw-r--r--pkg/tcpip/stack/nic_test.go269
-rw-r--r--pkg/tcpip/stack/nud.go466
-rw-r--r--pkg/tcpip/stack/nud_test.go795
-rw-r--r--pkg/tcpip/stack/packet_buffer.go27
-rw-r--r--pkg/tcpip/stack/registration.go93
-rw-r--r--pkg/tcpip/stack/route.go33
-rw-r--r--pkg/tcpip/stack/stack.go158
-rw-r--r--pkg/tcpip/stack/stack_options.go106
-rw-r--r--pkg/tcpip/stack/stack_test.go405
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go164
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go7
-rw-r--r--pkg/tcpip/stack/transport_test.go48
31 files changed, 9348 insertions, 1029 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index f71073207..1c58bed2d 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -16,6 +16,18 @@ go_template_instance(
)
go_template_instance(
+ name = "neighbor_entry_list",
+ out = "neighbor_entry_list.go",
+ package = "stack",
+ prefix = "neighborEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*neighborEntry",
+ "Linker": "*neighborEntry",
+ },
+)
+
+go_template_instance(
name = "packet_buffer_list",
out = "packet_buffer_list.go",
package = "stack",
@@ -27,6 +39,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tuple_list",
+ out = "tuple_list.go",
+ package = "stack",
+ prefix = "tuple",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*tuple",
+ "Linker": "*tuple",
+ },
+)
+
go_library(
name = "stack",
srcs = [
@@ -35,12 +59,18 @@ go_library(
"forwarder.go",
"icmp_rate_limit.go",
"iptables.go",
+ "iptables_state.go",
"iptables_targets.go",
"iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
"ndp.go",
+ "neighbor_cache.go",
+ "neighbor_entry.go",
+ "neighbor_entry_list.go",
+ "neighborstate_string.go",
"nic.go",
+ "nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
"rand.go",
@@ -48,7 +78,9 @@ go_library(
"route.go",
"stack.go",
"stack_global_state.go",
+ "stack_options.go",
"transport_demuxer.go",
+ "tuple_list.go",
],
visibility = ["//visibility:public"],
deps = [
@@ -74,10 +106,12 @@ go_test(
size = "medium",
srcs = [
"ndp_test.go",
+ "nud_test.go",
"stack_test.go",
"transport_demuxer_test.go",
"transport_test.go",
],
+ shard_count = 20,
deps = [
":stack",
"//pkg/rand",
@@ -89,10 +123,12 @@ go_test(
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
@@ -100,8 +136,11 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
+ "fake_time_test.go",
"forwarder_test.go",
"linkaddrcache_test.go",
+ "neighbor_cache_test.go",
+ "neighbor_entry_test.go",
"nic_test.go",
],
library = ":stack",
@@ -110,5 +149,9 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "@com_github_dpjacques_clockwork//:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 7d1ede1f2..470c265aa 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -20,332 +20,330 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)
// Connection tracking is used to track and manipulate packets for NAT rules.
-// The connection is created for a packet if it does not exist. Every connection
-// contains two tuples (original and reply). The tuples are manipulated if there
-// is a matching NAT rule. The packet is modified by looking at the tuples in the
-// Prerouting and Output hooks.
+// The connection is created for a packet if it does not exist. Every
+// connection contains two tuples (original and reply). The tuples are
+// manipulated if there is a matching NAT rule. The packet is modified by
+// looking at the tuples in the Prerouting and Output hooks.
+//
+// Currently, only TCP tracking is supported.
+
+// Our hash table has 16K buckets.
+// TODO(gvisor.dev/issue/170): These should be tunable.
+const numBuckets = 1 << 14
// Direction of the tuple.
-type ctDirection int
+type direction int
const (
- dirOriginal ctDirection = iota
+ dirOriginal direction = iota
dirReply
)
-// Status of connection.
-// TODO(gvisor.dev/issue/170): Add other states of connection.
-type connStatus int
-
-const (
- connNew connStatus = iota
- connEstablished
-)
-
// Manipulation type for the connection.
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
-// connTrackMutable is the manipulatable part of the tuple.
-type connTrackMutable struct {
- // addr is source address of the tuple.
- addr tcpip.Address
-
- // port is source port of the tuple.
- port uint16
-
- // protocol is network layer protocol.
- protocol tcpip.NetworkProtocolNumber
-}
-
-// connTrackImmutable is the non-manipulatable part of the tuple.
-type connTrackImmutable struct {
- // addr is destination address of the tuple.
- addr tcpip.Address
+// tuple holds a connection's identifying and manipulating data in one
+// direction. It is immutable.
+//
+// +stateify savable
+type tuple struct {
+ // tupleEntry is used to build an intrusive list of tuples.
+ tupleEntry
- // direction is direction (original or reply) of the tuple.
- direction ctDirection
+ tupleID
- // port is destination port of the tuple.
- port uint16
+ // conn is the connection tracking entry this tuple belongs to.
+ conn *conn
- // protocol is transport layer protocol.
- protocol tcpip.TransportProtocolNumber
+ // direction is the direction of the tuple.
+ direction direction
}
-// connTrackTuple represents the tuple which is created from the
-// packet.
-type connTrackTuple struct {
- // dst is non-manipulatable part of the tuple.
- dst connTrackImmutable
-
- // src is manipulatable part of the tuple.
- src connTrackMutable
+// tupleID uniquely identifies a connection in one direction. It currently
+// contains enough information to distinguish between any TCP or UDP
+// connection, and will need to be extended to support other protocols.
+//
+// +stateify savable
+type tupleID struct {
+ srcAddr tcpip.Address
+ srcPort uint16
+ dstAddr tcpip.Address
+ dstPort uint16
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
}
-// connTrackTupleHolder is the container of tuple and connection.
-type ConnTrackTupleHolder struct {
- // conn is pointer to the connection tracking entry.
- conn *connTrack
-
- // tuple is original or reply tuple.
- tuple connTrackTuple
+// reply creates the reply tupleID.
+func (ti tupleID) reply() tupleID {
+ return tupleID{
+ srcAddr: ti.dstAddr,
+ srcPort: ti.dstPort,
+ dstAddr: ti.srcAddr,
+ dstPort: ti.srcPort,
+ transProto: ti.transProto,
+ netProto: ti.netProto,
+ }
}
-// connTrack is the connection.
-type connTrack struct {
- // originalTupleHolder contains tuple in original direction.
- originalTupleHolder ConnTrackTupleHolder
-
- // replyTupleHolder contains tuple in reply direction.
- replyTupleHolder ConnTrackTupleHolder
-
- // status indicates connection is new or established.
- status connStatus
+// conn is a tracked connection.
+//
+// +stateify savable
+type conn struct {
+ // original is the tuple in original direction. It is immutable.
+ original tuple
- // timeout indicates the time connection should be active.
- timeout time.Duration
+ // reply is the tuple in reply direction. It is immutable.
+ reply tuple
- // manip indicates if the packet should be manipulated.
+ // manip indicates if the packet should be manipulated. It is immutable.
manip manipType
- // tcb is TCB control block. It is used to keep track of states
- // of tcp connection.
- tcb tcpconntrack.TCB
-
// tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb.
+ // update the state of tcb. It is immutable.
tcbHook Hook
-}
-// ConnTrackTable contains a map of all existing connections created for
-// NAT rules.
-type ConnTrackTable struct {
- // connMu protects connTrackTable.
- connMu sync.RWMutex
+ // mu protects all mutable state.
+ mu sync.Mutex `state:"nosave"`
+ // tcb is TCB control block. It is used to keep track of states
+ // of tcp connection and is protected by mu.
+ tcb tcpconntrack.TCB
+ // lastUsed is the last time the connection saw a relevant packet, and
+ // is updated by each packet on the connection. It is protected by mu.
+ lastUsed time.Time `state:".(unixTime)"`
+}
- // connTrackTable maintains a map of tuples needed for connection tracking
- // for iptables NAT rules. The key for the map is an integer calculated
- // using seed, source address, destination address, source port and
- // destination port.
- CtMap map[uint32]ConnTrackTupleHolder
+// timedOut returns whether the connection timed out based on its state.
+func (cn *conn) timedOut(now time.Time) bool {
+ const establishedTimeout = 5 * 24 * time.Hour
+ const defaultTimeout = 120 * time.Second
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+ if cn.tcb.State() == tcpconntrack.ResultAlive {
+ // Use the same default as Linux, which doesn't delete
+ // established connections for 5(!) days.
+ return now.Sub(cn.lastUsed) > establishedTimeout
+ }
+ // Use the same default as Linux, which lets connections in most states
+ // other than established remain for <= 120 seconds.
+ return now.Sub(cn.lastUsed) > defaultTimeout
+}
- // seed is a one-time random value initialized at stack startup
- // and is used in calculation of hash key for connection tracking
- // table.
- Seed uint32
+// update the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
}
-// parseHeaders sets headers in the packet.
-func parseHeaders(pkt *PacketBuffer) {
- newPkt := pkt.Clone()
+// ConnTrack tracks all connections created for NAT rules. Most users are
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
+//
+// ConnTrack keeps all connections in a slice of buckets, each of which holds a
+// linked list of tuples. This gives us some desirable properties:
+// - Each bucket has its own lock, lessening lock contention.
+// - The slice is large enough that lists stay short (<10 elements on average).
+// Thus traversal is fast.
+// - During linked list traversal we reap expired connections. This amortizes
+// the cost of reaping them and makes reapUnused faster.
+//
+// Locks are ordered by their location in the buckets slice. That is, a
+// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
+//
+// +stateify savable
+type ConnTrack struct {
+ // seed is a one-time random value initialized at stack startup
+ // and is used in the calculation of hash keys for the list of buckets.
+ // It is immutable.
+ seed uint32
- // Set network header.
- hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return
- }
- netHeader := header.IPv4(hdr)
- newPkt.NetworkHeader = hdr
- length := int(netHeader.HeaderLength())
+ // mu protects the buckets slice, but not buckets' contents. Only take
+ // the write lock if you are modifying the slice or saving for S/R.
+ mu sync.RWMutex `state:"nosave"`
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- // Set transport header.
- switch protocol := netHeader.TransportProtocol(); protocol {
- case header.UDPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.UDP(h[length:]))
- }
- case header.TCPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.TCP(h[length:]))
- }
- }
- pkt.NetworkHeader = newPkt.NetworkHeader
- pkt.TransportHeader = newPkt.TransportHeader
+ // buckets is protected by mu.
+ buckets []bucket
}
-// packetToTuple converts packet to a tuple in original direction.
-func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
- var tuple connTrackTuple
+// +stateify savable
+type bucket struct {
+ // mu protects tuples.
+ mu sync.Mutex `state:"nosave"`
+ tuples tupleList
+}
- netHeader := header.IPv4(pkt.NetworkHeader)
+// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
+// TCP header.
+func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
// TODO(gvisor.dev/issue/170): Need to support for other
// protocols as well.
+ netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tuple, tcpip.ErrUnknownProtocol
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
tcpHeader := header.TCP(pkt.TransportHeader)
if tcpHeader == nil {
- return tuple, tcpip.ErrUnknownProtocol
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
- tuple.src.addr = netHeader.SourceAddress()
- tuple.src.port = tcpHeader.SourcePort()
- tuple.src.protocol = header.IPv4ProtocolNumber
-
- tuple.dst.addr = netHeader.DestinationAddress()
- tuple.dst.port = tcpHeader.DestinationPort()
- tuple.dst.protocol = netHeader.TransportProtocol()
-
- return tuple, nil
-}
-
-// getReplyTuple creates reply tuple for the given tuple.
-func getReplyTuple(tuple connTrackTuple) connTrackTuple {
- var replyTuple connTrackTuple
- replyTuple.src.addr = tuple.dst.addr
- replyTuple.src.port = tuple.dst.port
- replyTuple.src.protocol = tuple.src.protocol
- replyTuple.dst.addr = tuple.src.addr
- replyTuple.dst.port = tuple.src.port
- replyTuple.dst.protocol = tuple.dst.protocol
- replyTuple.dst.direction = dirReply
-
- return replyTuple
+ return tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: tcpHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: tcpHeader.DestinationPort(),
+ transProto: netHeader.TransportProtocol(),
+ netProto: header.IPv4ProtocolNumber,
+ }, nil
}
-// makeNewConn creates new connection.
-func makeNewConn(tuple, replyTuple connTrackTuple) connTrack {
- var conn connTrack
- conn.status = connNew
- conn.originalTupleHolder.tuple = tuple
- conn.originalTupleHolder.conn = &conn
- conn.replyTupleHolder.tuple = replyTuple
- conn.replyTupleHolder.conn = &conn
-
- return conn
+// newConn creates new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
}
-// getTupleHash returns hash of the tuple. The fields used for
-// generating hash are seed (generated once for stack), source address,
-// destination address, source port and destination ports.
-func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
- h := jenkins.Sum32(ct.Seed)
- h.Write([]byte(tuple.src.addr))
- h.Write([]byte(tuple.dst.addr))
- portBuf := make([]byte, 2)
- binary.LittleEndian.PutUint16(portBuf, tuple.src.port)
- h.Write([]byte(portBuf))
- binary.LittleEndian.PutUint16(portBuf, tuple.dst.port)
- h.Write([]byte(portBuf))
-
- return h.Sum32()
+// connFor gets the conn for pkt if it exists, or returns nil
+// if it does not. It returns an error when pkt does not contain a valid TCP
+// header.
+// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
+// other transport protocols.
+func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil, dirOriginal
+ }
+ return ct.connForTID(tid)
}
-// connTrackForPacket returns connTrack for packet.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
-// transport protocols.
-func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
- if hook == Prerouting {
- // Headers will not be set in Prerouting.
- // TODO(gvisor.dev/issue/170): Change this after parsing headers
- // code is added.
- parseHeaders(pkt)
+func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
+ bucket := ct.bucket(tid)
+ now := time.Now()
+
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ ct.buckets[bucket].mu.Lock()
+ defer ct.buckets[bucket].mu.Unlock()
+
+ // Iterate over the tuples in a bucket, cleaning up any unused
+ // connections we find.
+ for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
+ // Clean up any timed-out connections we happen to find.
+ if ct.reapTupleLocked(other, bucket, now) {
+ // The tuple expired.
+ continue
+ }
+ if tid == other.tupleID {
+ return other.conn, other.direction
+ }
}
- var dir ctDirection
- tuple, err := packetToTuple(*pkt, hook)
- if err != nil {
- return nil, dir
- }
-
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
-
- connTrackTable := ct.CtMap
- hash := ct.getTupleHash(tuple)
-
- var conn *connTrack
- switch createConn {
- case true:
- // If connection does not exist for the hash, create a new
- // connection.
- replyTuple := getReplyTuple(tuple)
- replyHash := ct.getTupleHash(replyTuple)
- newConn := makeNewConn(tuple, replyTuple)
- conn = &newConn
-
- // Add tupleHolders to the map.
- // TODO(gvisor.dev/issue/170): Need to support collisions using linked list.
- ct.CtMap[hash] = conn.originalTupleHolder
- ct.CtMap[replyHash] = conn.replyTupleHolder
- default:
- tupleHolder, ok := connTrackTable[hash]
- if !ok {
- return nil, dir
- }
+ return nil, dirOriginal
+}
- // If this is the reply of new connection, set the connection
- // status as ESTABLISHED.
- conn = tupleHolder.conn
- if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply {
- conn.status = connEstablished
- }
- if tupleHolder.conn == nil {
- panic("tupleHolder has null connection tracking entry")
- }
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Prerouting && hook != Output {
+ return nil
+ }
- dir = tupleHolder.tuple.dst.direction
+ // Create a new connection and change the port as per the iptables
+ // rule. This tuple will be used to manipulate the packet in
+ // handlePacket.
+ replyTID := tid.reply()
+ replyTID.srcAddr = rt.MinIP
+ replyTID.srcPort = rt.MinPort
+ var manip manipType
+ switch hook {
+ case Prerouting:
+ manip = manipDstPrerouting
+ case Output:
+ manip = manipDstOutput
}
- return conn, dir
+ conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
}
-// SetNatInfo will manipulate the tuples according to iptables NAT rules.
-func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) {
- // Get the connection. Connection is always created before this
- // function is called.
- conn, _ := ct.connTrackForPacket(pkt, hook, false)
- if conn == nil {
- panic("connection should be created to manipulate tuples.")
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
+ // Lock the buckets in the correct order.
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ if tupleBucket < replyBucket {
+ ct.buckets[tupleBucket].mu.Lock()
+ ct.buckets[replyBucket].mu.Lock()
+ } else if tupleBucket > replyBucket {
+ ct.buckets[replyBucket].mu.Lock()
+ ct.buckets[tupleBucket].mu.Lock()
+ } else {
+ // Both tuples are in the same bucket.
+ ct.buckets[tupleBucket].mu.Lock()
}
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
- // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to
- // support changing of address for Prerouting.
-
- // Change the port as per the iptables rule. This tuple will be used
- // to manipulate the packet in HandlePacket.
- conn.replyTupleHolder.tuple.src.addr = rt.MinIP
- conn.replyTupleHolder.tuple.src.port = rt.MinPort
- newHash := ct.getTupleHash(conn.replyTupleHolder.tuple)
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
- // Add the changed tuple to the map.
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
- ct.CtMap[newHash] = conn.replyTupleHolder
- if hook == Output {
- conn.replyTupleHolder.conn.manip = manipDstOutput
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
}
- // Delete the old tuple.
- delete(ct.CtMap, replyHash)
+ // Unlocking can happen in any order.
+ ct.buckets[tupleBucket].mu.Unlock()
+ if tupleBucket != replyBucket {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook..
-func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) {
+// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
+func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -354,21 +352,31 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection)
// modified.
switch dir {
case dirOriginal:
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
case dirReply:
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) {
+func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -377,13 +385,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
// modified. For prerouting redirection, we only reach this point
// when replying, so packet sources are modified.
if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
} else {
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
// Calculate the TCP checksum and set it.
@@ -402,33 +410,32 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
-// HandlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// handlePacket will manipulate the port and address of the packet if the
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
}
- conn, dir := ct.connTrackForPacket(pkt, hook, false)
- // Connection or Rule not found for the packet.
- if conn == nil {
- return
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
- netHeader := header.IPv4(pkt.NetworkHeader)
- // TODO(gvisor.dev/issue/170): Need to support for other transport
- // protocols as well.
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return
+ conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
+ if conn == nil {
+ return true
}
tcpHeader := header.TCP(pkt.TransportHeader)
if tcpHeader == nil {
- return
+ return false
}
switch hook {
@@ -442,39 +449,184 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
// other tcp states.
- var st tcpconntrack.Result
- if conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else {
- switch hook {
- case conn.tcbHook:
- st = conn.tcb.UpdateStateOutbound(tcpHeader)
- default:
- st = conn.tcb.UpdateStateInbound(tcpHeader)
- }
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ conn.lastUsed = time.Now()
+ // Update connection state.
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
}
- // Delete conntrack if tcp connection is closed.
- if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
- ct.deleteConnTrack(conn)
+ // We only track TCP connections.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return
}
-}
-// deleteConnTrack deletes the connection.
-func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) {
- if conn == nil {
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
return
}
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ ct.insertConn(conn)
+}
+
+// bucket gets the conntrack bucket for a tupleID.
+func (ct *ConnTrack) bucket(id tupleID) int {
+ h := jenkins.Sum32(ct.seed)
+ h.Write([]byte(id.srcAddr))
+ h.Write([]byte(id.dstAddr))
+ shortBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
+ h.Write([]byte(shortBuf))
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ return int(h.Sum32()) % len(ct.buckets)
+}
+
+// reapUnused deletes timed out entries from the conntrack map. The rules for
+// reaping are:
+// - Most reaping occurs in connFor, which is called on each packet. connFor
+// cleans up the bucket the packet's connection maps to. Thus calls to
+// reapUnused should be fast.
+// - Each call to reapUnused traverses a fraction of the conntrack table.
+// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
+// - After reaping, reapUnused decides when it should next run based on the
+// ratio of expired connections to examined connections. If the ratio is
+// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
+// slightly increases the interval between runs.
+// - maxFullTraversal caps the time it takes to traverse the entire table.
+//
+// reapUnused returns the next bucket that should be checked and the time after
+// which it should be called again.
+func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
+ // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
+ // it is in Linux via sysctl.
+ const fractionPerReaping = 128
+ const maxExpiredPct = 50
+ const maxFullTraversal = 60 * time.Second
+ const minInterval = 10 * time.Millisecond
+ const maxInterval = maxFullTraversal / fractionPerReaping
+
+ now := time.Now()
+ checked := 0
+ expired := 0
+ var idx int
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
+ idx = (i + start) % len(ct.buckets)
+ ct.buckets[idx].mu.Lock()
+ for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ checked++
+ if ct.reapTupleLocked(tuple, idx, now) {
+ expired++
+ }
+ }
+ ct.buckets[idx].mu.Unlock()
+ }
+ // We already checked buckets[idx].
+ idx++
+
+ // If half or more of the connections are expired, the table has gotten
+ // stale. Reschedule quickly.
+ expiredPct := 0
+ if checked != 0 {
+ expiredPct = expired * 100 / checked
+ }
+ if expiredPct > maxExpiredPct {
+ return idx, minInterval
+ }
+ if interval := prevInterval + minInterval; interval <= maxInterval {
+ // Increment the interval between runs.
+ return idx, interval
+ }
+ // We've hit the maximum interval.
+ return idx, maxInterval
+}
- tuple := conn.originalTupleHolder.tuple
- hash := ct.getTupleHash(tuple)
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
+// reapTupleLocked tries to remove tuple and its reply from the table. It
+// returns whether the tuple's connection has timed out.
+//
+// Preconditions: ct.mu is locked for reading and bucket is locked.
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+ if !tuple.conn.timedOut(now) {
+ return false
+ }
+
+ // To maintain lock order, we can only reap these tuples if the reply
+ // appears later in the table.
+ replyBucket := ct.bucket(tuple.reply())
+ if bucket > replyBucket {
+ return true
+ }
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
+ // Don't re-lock if both tuples are in the same bucket.
+ differentBuckets := bucket != replyBucket
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Lock()
+ }
+
+ // We have the buckets locked and can remove both tuples.
+ if tuple.direction == dirOriginal {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ } else {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
+ }
+ ct.buckets[bucket].tuples.Remove(tuple)
+
+ // Don't re-unlock if both tuples are in the same bucket.
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
+
+ return true
+}
+
+func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ // Lookup the connection. The reply's original destination
+ // describes the original address.
+ tid := tupleID{
+ srcAddr: epID.LocalAddress,
+ srcPort: epID.LocalPort,
+ dstAddr: epID.RemoteAddress,
+ dstPort: epID.RemotePort,
+ transProto: header.TCPProtocolNumber,
+ netProto: header.IPv4ProtocolNumber,
+ }
+ conn, _ := ct.connForTID(tid)
+ if conn == nil {
+ // Not a tracked connection.
+ return "", 0, tcpip.ErrNotConnected
+ } else if conn.manip == manipNone {
+ // Unmanipulated connection.
+ return "", 0, tcpip.ErrInvalidOptionValue
+ }
- delete(ct.CtMap, hash)
- delete(ct.CtMap, replyHash)
+ return conn.original.dstAddr, conn.original.dstPort, nil
}
diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go
new file mode 100644
index 000000000..92c8cb534
--- /dev/null
+++ b/pkg/tcpip/stack/fake_time_test.go
@@ -0,0 +1,209 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "container/heap"
+ "sync"
+ "time"
+
+ "github.com/dpjacques/clockwork"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+type fakeClock struct {
+ clock clockwork.FakeClock
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ // times is min-heap of times. A heap is used for quick retrieval of the next
+ // upcoming time of scheduled work.
+ times *timeHeap
+
+ // waitGroups stores one WaitGroup for all work scheduled to execute at the
+ // same time via AfterFunc. This allows parallel execution of all functions
+ // passed to AfterFunc scheduled for the same time.
+ waitGroups map[time.Time]*sync.WaitGroup
+}
+
+func newFakeClock() *fakeClock {
+ return &fakeClock{
+ clock: clockwork.NewFakeClock(),
+ times: &timeHeap{},
+ waitGroups: make(map[time.Time]*sync.WaitGroup),
+ }
+}
+
+var _ tcpip.Clock = (*fakeClock)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (fc *fakeClock) NowNanoseconds() int64 {
+ return fc.clock.Now().UnixNano()
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (fc *fakeClock) NowMonotonic() int64 {
+ return fc.NowNanoseconds()
+}
+
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ until := fc.clock.Now().Add(d)
+ wg := fc.addWait(until)
+ return &fakeTimer{
+ clock: fc,
+ until: until,
+ timer: fc.clock.AfterFunc(d, func() {
+ defer wg.Done()
+ f()
+ }),
+ }
+}
+
+// addWait adds an additional wait to the WaitGroup for parallel execution of
+// all work scheduled for t. Returns a reference to the WaitGroup modified.
+func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup {
+ fc.mu.RLock()
+ wg, ok := fc.waitGroups[t]
+ fc.mu.RUnlock()
+
+ if ok {
+ wg.Add(1)
+ return wg
+ }
+
+ fc.mu.Lock()
+ heap.Push(fc.times, t)
+ fc.mu.Unlock()
+
+ wg = &sync.WaitGroup{}
+ wg.Add(1)
+
+ fc.mu.Lock()
+ fc.waitGroups[t] = wg
+ fc.mu.Unlock()
+
+ return wg
+}
+
+// removeWait removes a wait from the WaitGroup for parallel execution of all
+// work scheduled for t.
+func (fc *fakeClock) removeWait(t time.Time) {
+ fc.mu.RLock()
+ defer fc.mu.RUnlock()
+
+ wg := fc.waitGroups[t]
+ wg.Done()
+}
+
+// advance executes all work that have been scheduled to execute within d from
+// the current fake time. Blocks until all work has completed execution.
+func (fc *fakeClock) advance(d time.Duration) {
+ // Block until all the work is done
+ until := fc.clock.Now().Add(d)
+ for {
+ fc.mu.Lock()
+ if fc.times.Len() == 0 {
+ fc.mu.Unlock()
+ return
+ }
+
+ t := heap.Pop(fc.times).(time.Time)
+ if t.After(until) {
+ // No work to do
+ heap.Push(fc.times, t)
+ fc.mu.Unlock()
+ return
+ }
+ fc.mu.Unlock()
+
+ diff := t.Sub(fc.clock.Now())
+ fc.clock.Advance(diff)
+
+ fc.mu.RLock()
+ wg := fc.waitGroups[t]
+ fc.mu.RUnlock()
+
+ wg.Wait()
+
+ fc.mu.Lock()
+ delete(fc.waitGroups, t)
+ fc.mu.Unlock()
+ }
+}
+
+type fakeTimer struct {
+ clock *fakeClock
+ timer clockwork.Timer
+
+ mu sync.RWMutex
+ until time.Time
+}
+
+var _ tcpip.Timer = (*fakeTimer)(nil)
+
+// Reset implements tcpip.Timer.Reset.
+func (ft *fakeTimer) Reset(d time.Duration) {
+ if !ft.timer.Reset(d) {
+ return
+ }
+
+ ft.mu.Lock()
+ defer ft.mu.Unlock()
+
+ ft.clock.removeWait(ft.until)
+ ft.until = ft.clock.clock.Now().Add(d)
+ ft.clock.addWait(ft.until)
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (ft *fakeTimer) Stop() bool {
+ if !ft.timer.Stop() {
+ return false
+ }
+
+ ft.mu.RLock()
+ defer ft.mu.RUnlock()
+
+ ft.clock.removeWait(ft.until)
+ return true
+}
+
+type timeHeap []time.Time
+
+var _ heap.Interface = (*timeHeap)(nil)
+
+func (h timeHeap) Len() int {
+ return len(h)
+}
+
+func (h timeHeap) Less(i, j int) bool {
+ return h[i].Before(h[j])
+}
+
+func (h timeHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+func (h *timeHeap) Push(x interface{}) {
+ *h = append(*h, x.(time.Time))
+}
+
+func (h *timeHeap) Pop() interface{} {
+ last := (*h)[len(*h)-1]
+ *h = (*h)[:len(*h)-1]
+ return last
+}
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go
index 6b64cd37f..3eff141e6 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/forwarder.go
@@ -32,7 +32,7 @@ type pendingPacket struct {
nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
- pkt PacketBuffer
+ pkt *PacketBuffer
}
type forwardQueue struct {
@@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue {
return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
shouldWait := false
f.Lock()
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 8084d50bc..c962693f5 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -33,6 +34,10 @@ const (
// except where another value is explicitly used. It is chosen to match
// the MTU of loopback interfaces on linux systems.
fwdTestNetDefaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
@@ -68,16 +73,9 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
return &f.id
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
- // Consume the network header.
- b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
- if !ok {
- return
- }
- pkt.Data.TrimFront(fwdTestNetHeaderLen)
-
+func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -96,13 +94,13 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
return f.proto.Number()
}
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.Header.Prepend(fwdTestNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
+ b[dstAddrOffset] = r.RemoteAddress[0]
+ b[srcAddrOffset] = f.id.LocalAddress[0]
+ b[protocolNumberOffset] = byte(params.Protocol)
return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
@@ -112,7 +110,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf
panic("not implemented")
}
-func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error {
+func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -123,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {}
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
@@ -140,7 +140,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
}
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
+}
+
+func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = netHeader
+ pkt.Data.TrimFront(fwdTestNetHeaderLen)
+ return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true
}
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
@@ -166,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {}
func (f *fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
if f.addrCache != nil && f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, addr)
+ f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr)
})
}
return nil
@@ -190,7 +200,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress
LocalLinkAddress tcpip.LinkAddress
- Pkt PacketBuffer
+ Pkt *PacketBuffer
}
type fwdTestLinkEndpoint struct {
@@ -203,13 +213,13 @@ type fwdTestLinkEndpoint struct {
}
// InjectInbound injects an inbound packet.
-func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) {
- e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
+func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
}
// Attach saves the stack network-layer dispatcher for use later when packets
@@ -251,7 +261,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -270,7 +280,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw
func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.WritePacket(r, gso, protocol, *pkt)
+ e.WritePacket(r, gso, protocol, pkt)
n++
}
@@ -280,7 +290,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := fwdTestPacketInfo{
- Pkt: PacketBuffer{Data: vv},
+ Pkt: &PacketBuffer{Data: vv},
}
select {
@@ -294,6 +304,16 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
@@ -361,8 +381,8 @@ func TestForwardingWithStaticResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -387,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any address will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -398,8 +418,8 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -429,8 +449,8 @@ func TestForwardingWithNoResolver(t *testing.T) {
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -445,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
@@ -459,16 +479,16 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Inject an inbound packet to address 4 on NIC 1. This packet should
// not be forwarded.
buf := buffer.NewView(30)
- buf[0] = 4
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf = buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -480,9 +500,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -498,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -509,8 +528,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Inject two inbound packets to address 3 on NIC 1.
for i := 0; i < 2; i++ {
buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -524,9 +543,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -543,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -554,10 +572,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
buf := buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
// Set the packet sequence number.
binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -571,14 +589,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
}
- // The first 5 packets should not be forwarded so the the
- // sequemnce number should start with 5.
+ seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes).
+ if !ok {
+ t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size())
+ }
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want {
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
t.Fatalf("got the packet #%d, want = #%d", n, want)
}
@@ -596,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -609,8 +631,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Each packet has a different destination address (3 to
// maxPendingResolutions + 7).
buf := buffer.NewView(30)
- buf[0] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -626,9 +648,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
- b := p.Pkt.Data.ToView()
- if b[0] < 8 {
- t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 443423b3c..110ba073d 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -16,40 +16,49 @@ package stack
import (
"fmt"
- "strings"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// Table names.
+// tableID is an index into IPTables.tables.
+type tableID int
+
const (
- TablenameNat = "nat"
- TablenameMangle = "mangle"
- TablenameFilter = "filter"
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
)
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+// Table names.
const (
- ChainNamePrerouting = "PREROUTING"
- ChainNameInput = "INPUT"
- ChainNameForward = "FORWARD"
- ChainNameOutput = "OUTPUT"
- ChainNamePostrouting = "POSTROUTING"
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
)
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
+
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
+// reaperDelay is how long to wait before starting to reap connections.
+const reaperDelay = 5 * time.Second
+
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() IPTables {
- // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
- // iotas.
- return IPTables{
- Tables: map[string]Table{
- TablenameNat: Table{
+func DefaultTables() *IPTables {
+ return &IPTables{
+ tables: [numTables]Table{
+ natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
@@ -57,65 +66,71 @@ func DefaultTables() IPTables {
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- UserChains: map[string]int{},
},
- TablenameMangle: Table{
+ mangleID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- Underflows: map[Hook]int{
- Prerouting: 0,
- Output: 1,
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
- TablenameFilter: Table{
+ filterID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
},
- Priorities: map[Hook][]string{
- Input: []string{TablenameNat, TablenameFilter},
- Prerouting: []string{TablenameMangle, TablenameNat},
- Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
},
- connections: ConnTrackTable{
- CtMap: make(map[uint32]ConnTrackTupleHolder),
- Seed: generateRandUint32(),
+ connections: ConnTrack{
+ seed: generateRandUint32(),
},
+ reaperDone: make(chan struct{}, 1),
}
}
@@ -124,41 +139,61 @@ func DefaultTables() IPTables {
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// EmptyNatTable returns a Table with no rules and the filter table chains
+// EmptyNATTable returns a Table with no rules and the filter table chains
// mapped to HookUnset.
-func EmptyNatTable() Table {
+func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
},
- Underflows: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
},
- UserChains: map[string]int{},
}
}
+// GetTable returns a table by name.
+func (it *IPTables) GetTable(name string) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ return it.tables[id], true
+}
+
+// ReplaceTable replaces or inserts table by name.
+func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ // If iptables is being enabled, initialize the conntrack table and
+ // reaper.
+ if !it.modified {
+ it.connections.buckets = make([]bucket, numBuckets)
+ it.startReaper(reaperDelay)
+ }
+ it.modified = true
+ it.tables[id] = table
+ return nil
+}
+
// A chainVerdict is what a table decides should be done with a packet.
type chainVerdict int
@@ -180,13 +215,27 @@ const (
//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+ // Many users never configure iptables. Spare them the cost of rule
+ // traversal if rules have never been set.
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ if !it.modified {
+ return true
+ }
+
// Packets are manipulated only if connection and matching
// NAT rule exists.
- it.connections.HandlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.Priorities[hook] {
- table := it.Tables[tablename]
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ table := it.tables[tableID]
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
@@ -215,17 +264,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
}
}
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+ // needeed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
// Every table returned Accept.
return true
}
+// beforeSave is invoked by stateify.
+func (it *IPTables) beforeSave() {
+ // Ensure the reaper exits cleanly.
+ it.reaperDone <- struct{}{}
+ // Prevent others from modifying the connection table.
+ it.connections.mu.Lock()
+}
+
+// afterLoad is invoked by stateify.
+func (it *IPTables) afterLoad() {
+ it.startReaper(reaperDelay)
+}
+
+// startReaper starts a goroutine that wakes up periodically to reap timed out
+// connections.
+func (it *IPTables) startReaper(interval time.Duration) {
+ go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved.
+ bucket := 0
+ for {
+ select {
+ case <-it.reaperDone:
+ return
+ case <-time.After(interval):
+ bucket, interval = it.connections.reapUnused(bucket, interval)
+ }
+ }
+ }()
+}
+
// CheckPackets runs pkts through the rules for hook and returns a map of packets that
// should not go forward.
//
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-//
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
@@ -249,9 +340,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
return drop, natPkts
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
@@ -296,25 +387,14 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
return chainDrop
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
- // If pkt.NetworkHeader hasn't been set yet, it will be contained in
- // pkt.Data.
- if pkt.NetworkHeader == nil {
- var ok bool
- pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // Precondition has been violated.
- panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
- }
- }
-
// Check whether the packet matches the IP header filter.
- if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) {
+ if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -322,7 +402,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
- matches, hotdrop := matcher.Match(hook, *pkt, "")
+ matches, hotdrop := matcher.Match(hook, pkt, "")
if hotdrop {
return RuleDrop, 0
}
@@ -336,46 +416,8 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
}
-func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool {
- // TODO(gvisor.dev/issue/170): Support other fields of the filter.
- // Check the transport protocol.
- if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() {
- return false
- }
-
- // Check the destination IP.
- dest := hdr.DestinationAddress()
- matches := true
- for i := range filter.Dst {
- if dest[i]&filter.DstMask[i] != filter.Dst[i] {
- matches = false
- break
- }
- }
- if matches == filter.DstInvert {
- return false
- }
-
- // Check the output interface.
- // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
- // hooks after supported.
- if hook == Output {
- n := len(filter.OutputInterface)
- if n == 0 {
- return true
- }
-
- // If the interface name ends with '+', any interface which begins
- // with the name should be matched.
- ifName := filter.OutputInterface
- matches = true
- if strings.HasSuffix(ifName, "+") {
- matches = strings.HasPrefix(nicName, ifName[:n-1])
- } else {
- matches = nicName == ifName
- }
- return filter.OutputInterfaceInvert != matches
- }
-
- return true
+// OriginalDst returns the original destination of redirected connections. It
+// returns an error if the connection doesn't exist or isn't redirected.
+func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ return it.connections.originalDst(epID)
}
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
new file mode 100644
index 000000000..529e02a07
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastUsed is invoked by stateify.
+func (cn *conn) saveLastUsed() unixTime {
+ return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
+}
+
+// loadLastUsed is invoked by stateify.
+func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.lastUsed = time.Unix(unix.second, unix.nano)
+}
+
+// beforeSave is invoked by stateify.
+func (ct *ConnTrack) beforeSave() {
+ ct.mu.Lock()
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 36cc6275d..dc88033c7 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -24,7 +24,7 @@ import (
type AcceptTarget struct{}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -32,7 +32,7 @@ func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, t
type DropTarget struct{}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -41,7 +41,7 @@ func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcp
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -61,7 +61,7 @@ func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route
type ReturnTarget struct{}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -92,17 +92,12 @@ type RedirectTarget struct {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
}
- // Set network header.
- if hook == Prerouting {
- parseHeaders(pkt)
- }
-
// Drop the packet if network and transport header are not set.
if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
return RuleDrop, 0
@@ -155,12 +150,11 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook
return RuleAccept, 0
}
- // Set up conection for matching NAT rule.
- // Only the first packet of the connection comes here.
- // Other packets will be manipulated in connection tracking.
- if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil {
- ct.SetNatInfo(pkt, rt, hook)
- ct.HandlePacket(pkt, hook, gso, r)
+ // Set up conection for matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ ct.handlePacket(pkt, hook, gso, r)
}
default:
return RuleDrop, 0
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index fe06007ae..73274ada9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -15,7 +15,11 @@
package stack
import (
+ "strings"
+ "sync"
+
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
// A Hook specifies one of the hooks built into the network stack.
@@ -74,63 +78,65 @@ const (
)
// IPTables holds all the tables for a netstack.
+//
+// +stateify savable
type IPTables struct {
- // Tables maps table names to tables. User tables have arbitrary names.
- Tables map[string]Table
+ // mu protects tables, priorities, and modified.
+ mu sync.RWMutex
- // Priorities maps each hook to a list of table names. The order of the
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be locked for accessing.
+ tables [numTables]Table
+
+ // priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
- // hook.
- Priorities map[Hook][]string
+ // hook. mu needs to be locked for accessing.
+ priorities [NumHooks][]tableID
+
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ modified bool
- connections ConnTrackTable
+ connections ConnTrack
+
+ // reaperDone can be signalled to stop the reaper goroutine.
+ reaperDone chan struct{}
}
// A Table defines a set of chains and hooks into the network stack. It is
-// really just a list of rules with some metadata for entrypoints and such.
+// really just a list of rules.
+//
+// +stateify savable
type Table struct {
// Rules holds the rules that make up the table.
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]int
-
- // Metadata holds information about the Table that is useful to users
- // of IPTables, but not to the netstack IPTables code itself.
- metadata interface{}
+ Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook := range table.BuiltinChains {
- hooks |= 1 << hook
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
}
return hooks
}
-// Metadata returns the metadata object stored in table.
-func (table *Table) Metadata() interface{} {
- return table.metadata
-}
-
-// SetMetadata sets the metadata object stored in table.
-func (table *Table) SetMetadata(metadata interface{}) {
- table.metadata = metadata
-}
-
// A Rule is a packet processing rule. It consists of two pieces. First it
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
// applies to any packet.
+//
+// +stateify savable
type Rule struct {
// Filter holds basic IP filtering fields common to every rule.
Filter IPHeaderFilter
@@ -143,6 +149,8 @@ type Rule struct {
}
// IPHeaderFilter holds basic IP filtering data common to every rule.
+//
+// +stateify savable
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
@@ -159,6 +167,16 @@ type IPHeaderFilter struct {
// comparison.
DstInvert bool
+ // Src matches the source IP address.
+ Src tcpip.Address
+
+ // SrcMask masks bits of the source IP address when comparing with Src.
+ SrcMask tcpip.Address
+
+ // SrcInvert inverts the meaning of the source IP check, i.e. when true the
+ // filter will match packets that fail the source comparison.
+ SrcInvert bool
+
// OutputInterface matches the name of the outgoing interface for the
// packet.
OutputInterface string
@@ -173,6 +191,55 @@ type IPHeaderFilter struct {
OutputInterfaceInvert bool
}
+// match returns whether hdr matches the filter.
+func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool {
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ // Check the transport protocol.
+ if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() {
+ return false
+ }
+
+ // Check the source and destination IPs.
+ if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) {
+ return false
+ }
+
+ // Check the output interface.
+ // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
+ // hooks after supported.
+ if hook == Output {
+ n := len(fl.OutputInterface)
+ if n == 0 {
+ return true
+ }
+
+ // If the interface name ends with '+', any interface which begins
+ // with the name should be matched.
+ ifName := fl.OutputInterface
+ matches := true
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return fl.OutputInterfaceInvert != matches
+ }
+
+ return true
+}
+
+// filterAddress returns whether addr matches the filter.
+func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
+ matches := true
+ for i := range filterAddr {
+ if addr[i]&mask[i] != filterAddr[i] {
+ matches = false
+ break
+ }
+ }
+ return matches != invert
+}
+
// A Matcher is the interface for matching packets.
type Matcher interface {
// Name returns the name of the Matcher.
@@ -183,7 +250,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
@@ -191,5 +258,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
+ Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 403557fd7..6f73a0ce4 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 1baa498d0..b15b8d1cb 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
if f := r.onLinkAddressRequest; f != nil {
f()
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 526c7d6ff..5174e639c 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -33,12 +33,6 @@ const (
// Default = 1 (from RFC 4862 section 5.1)
defaultDupAddrDetectTransmits = 1
- // defaultRetransmitTimer is the default amount of time to wait between
- // sending NDP Neighbor solicitation messages.
- //
- // Default = 1s (from RFC 4861 section 10).
- defaultRetransmitTimer = time.Second
-
// defaultMaxRtrSolicitations is the default number of Router
// Solicitation messages to send when a NIC becomes enabled.
//
@@ -79,16 +73,6 @@ const (
// Default = true.
defaultAutoGenGlobalAddresses = true
- // minimumRetransmitTimer is the minimum amount of time to wait between
- // sending NDP Neighbor solicitation messages. Note, RFC 4861 does
- // not impose a minimum Retransmit Timer, but we do here to make sure
- // the messages are not sent all at once. We also come to this value
- // because in the RetransmitTimer field of a Router Advertisement, a
- // value of 0 means unspecified, so the smallest valid value is 1.
- // Note, the unit of the RetransmitTimer field in the Router
- // Advertisement is milliseconds.
- minimumRetransmitTimer = time.Millisecond
-
// minimumRtrSolicitationInterval is the minimum amount of time to wait
// between sending Router Solicitation messages. This limit is imposed
// to make sure that Router Solicitation messages are not sent all at
@@ -467,8 +451,17 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- // The timer used to send the next router solicitation message.
- rtrSolicitTimer *time.Timer
+ rtrSolicit struct {
+ // The timer used to send the next router solicitation message.
+ timer tcpip.Timer
+
+ // Used to let the Router Solicitation timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this ndpState is associated with. MUST be set when the timer is
+ // set.
+ done *bool
+ }
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -494,7 +487,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
@@ -506,38 +499,38 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- // Timer to invalidate the default router.
+ // Job to invalidate the default router.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- // Timer to invalidate the on-link prefix.
+ // Job to invalidate the on-link prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// tempSLAACAddrState holds state associated with a temporary SLAAC address.
type tempSLAACAddrState struct {
- // Timer to deprecate the temporary SLAAC address.
+ // Job to deprecate the temporary SLAAC address.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the temporary SLAAC address.
+ // Job to invalidate the temporary SLAAC address.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
- // Timer to regenerate the temporary SLAAC address.
+ // Job to regenerate the temporary SLAAC address.
//
// Must not be nil.
- regenTimer *tcpip.CancellableTimer
+ regenJob *tcpip.Job
createdAt time.Time
@@ -552,15 +545,15 @@ type tempSLAACAddrState struct {
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
- // Timer to deprecate the prefix.
+ // Job to deprecate the prefix.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the prefix.
+ // Job to invalidate the prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
validUntil time.Time
@@ -642,19 +635,20 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
var done bool
- var timer *time.Timer
+ var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
// cannot be done while holding the NIC's lock. This is effectively the same
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
- timer = time.AfterFunc(0, func() {
- ndp.nic.mu.RLock()
+ timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
if done {
// If we reach this point, it means that the DAD timer fired after
// another goroutine already obtained the NIC lock and stopped DAD
// before this function obtained the NIC lock. Simply return here and do
// nothing further.
- ndp.nic.mu.RUnlock()
return
}
@@ -665,15 +659,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
dadDone := remaining == 0
- ndp.nic.mu.RUnlock()
var err *tcpip.Error
if !dadDone {
- err = ndp.sendDADPacket(addr)
+ // Use the unspecified address as the source address when performing DAD.
+ ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+
+ // Do not hold the lock when sending packets which may be a long running
+ // task or may block link address resolution. We know this is safe
+ // because immediately after obtaining the lock again, we check if DAD
+ // has been stopped before doing any work with the NIC. Note, DAD would be
+ // stopped if the NIC was disabled or removed, or if the address was
+ // removed.
+ ndp.nic.mu.Unlock()
+ err = ndp.sendDADPacket(addr, ref)
+ ndp.nic.mu.Lock()
}
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
if done {
// If we reach this point, it means that DAD was stopped after we released
// the NIC's read lock and before we obtained the write lock.
@@ -721,17 +723,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// addr.
//
// addr must be a tentative IPv6 address on ndp's NIC.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
+//
+// The NIC ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- // Use the unspecified address as the source address when performing DAD.
- ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
+ return err
+ }
+
panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
} else if c != nil {
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
@@ -750,7 +759,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, PacketBuffer{Header: hdr},
+ }, &PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
return err
@@ -846,9 +855,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
- // the invalidation timer.
- rtr.invalidationTimer.StopLocked()
- rtr.invalidationTimer.Reset(rl)
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
@@ -925,7 +934,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.StopLocked()
+ rtr.invalidationJob.Cancel()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
@@ -954,12 +963,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
state := defaultRouterState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
- state.invalidationTimer.Reset(rl)
+ state.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = state
}
@@ -984,13 +993,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
state := onLinkPrefixState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
if l < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(l)
+ state.invalidationJob.Schedule(l)
}
ndp.onLinkPrefixes[prefix] = state
@@ -1008,7 +1017,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- s.invalidationTimer.StopLocked()
+ s.invalidationJob.Cancel()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
@@ -1057,14 +1066,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
//
- // Update the invalidation timer.
+ // Update the invalidation job.
- prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationJob.Cancel()
if vl < header.NDPInfiniteLifetime {
- // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
// the new valid lifetime.
- prefixState.invalidationTimer.Reset(vl)
+ prefixState.invalidationJob.Schedule(vl)
}
ndp.onLinkPrefixes[prefix] = prefixState
@@ -1129,7 +1138,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1137,7 +1146,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1159,19 +1168,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if !ndp.generateSLAACAddr(prefix, &state) {
// We were unable to generate an address for the prefix, we do not nothing
- // further as there is no reason to maintain state or timers for a prefix we
+ // further as there is no reason to maintain state or jobs for a prefix we
// do not have an address for.
return
}
- // Setup the initial timers to deprecate and invalidate prefix.
+ // Setup the initial jobs to deprecate and invalidate prefix.
if pl < header.NDPInfiniteLifetime && pl != 0 {
- state.deprecationTimer.Reset(pl)
+ state.deprecationJob.Schedule(pl)
}
if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
+ state.invalidationJob.Schedule(vl)
state.validUntil = now.Add(vl)
}
@@ -1403,7 +1412,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1416,7 +1425,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1429,7 +1438,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1456,9 +1465,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ref: ref,
}
- state.deprecationTimer.Reset(pl)
- state.invalidationTimer.Reset(vl)
- state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
prefixState.generationAttempts++
prefixState.tempAddrs[generatedAddr.Address] = state
@@ -1493,16 +1502,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
prefixState.stableAddr.ref.deprecated = false
}
- // If prefix was preferred for some finite lifetime before, stop the
- // deprecation timer so it can be reset.
- prefixState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
now := time.Now()
- // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
- prefixState.deprecationTimer.Reset(pl)
+ prefixState.deprecationJob.Schedule(pl)
}
prefixState.preferredUntil = now.Add(pl)
} else {
@@ -1521,9 +1530,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
if vl >= header.NDPInfiniteLifetime {
- // Handle the infinite valid lifetime separately as we do not keep a timer
- // in this case.
- prefixState.invalidationTimer.StopLocked()
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
prefixState.validUntil = time.Time{}
} else {
var effectiveVl time.Duration
@@ -1544,8 +1553,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
if effectiveVl != 0 {
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
prefixState.validUntil = now.Add(effectiveVl)
}
}
@@ -1557,7 +1566,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// Note, we do not need to update the entries in the temporary address map
- // after updating the timers because the timers are held as pointers.
+ // after updating the jobs because the jobs are held as pointers.
var regenForAddr tcpip.Address
allAddressesRegenerated := true
for tempAddr, tempAddrState := range prefixState.tempAddrs {
@@ -1571,14 +1580,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
- // reset the invalidation timer.
+ // reset the invalidation job.
newValidLifetime := validUntil.Sub(now)
if newValidLifetime <= 0 {
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
continue
}
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.invalidationTimer.Reset(newValidLifetime)
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
// As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
// address is the lower of the preferred lifetime of the stable address or
@@ -1591,17 +1600,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer preferred, deprecate it immediately.
- // Otherwise, reset the deprecation timer.
+ // Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
- tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.ref)
} else {
tempAddrState.ref.deprecated = false
- tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.regenJob.Cancel()
if tempAddrState.regenerated {
} else {
allAddressesRegenerated = false
@@ -1612,7 +1621,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// immediately after we finish iterating over the temporary addresses.
regenForAddr = tempAddr
} else {
- tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
}
}
}
@@ -1692,7 +1701,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
//
// Panics if the SLAAC prefix is not known.
//
@@ -1704,8 +1713,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
}
state.stableAddr.ref = nil
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
}
@@ -1750,13 +1759,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
}
// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
-// timers and entry.
+// jobs and entry.
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
- tempAddrState.deprecationTimer.StopLocked()
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
@@ -1816,7 +1825,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The NIC ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicitTimer != nil {
+ if ndp.rtrSolicit.timer != nil {
// We are already soliciting routers.
return
}
@@ -1833,14 +1842,27 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
+ var done bool
+ ndp.rtrSolicit.done = &done
+ ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
+ ndp.nic.mu.Lock()
+ if done {
+ // If we reach this point, it means that the RS timer fired after another
+ // goroutine already obtained the NIC lock and stopped solicitations.
+ // Simply return here and do nothing further.
+ ndp.nic.mu.Unlock()
+ return
+ }
+
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress)
+ ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
if ref == nil {
- ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
}
+ ndp.nic.mu.Unlock()
+
localAddr := ref.ep.ID().LocalAddress
r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
@@ -1849,6 +1871,13 @@ func (ndp *ndpState) startSolicitingRouters() {
// header.IPv6AllRoutersMulticastAddress is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return
+ }
+
panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
} else if c != nil {
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
@@ -1881,7 +1910,7 @@ func (ndp *ndpState) startSolicitingRouters() {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, PacketBuffer{Header: hdr},
+ }, &PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
@@ -1893,17 +1922,18 @@ func (ndp *ndpState) startSolicitingRouters() {
}
ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
- if remaining == 0 {
- ndp.rtrSolicitTimer = nil
- } else if ndp.rtrSolicitTimer != nil {
+ if done || remaining == 0 {
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+ } else if ndp.rtrSolicit.timer != nil {
// Note, we need to explicitly check to make sure that
// the timer field is not nil because if it was nil but
// we still reached this point, then we know the NIC
// was requested to stop soliciting routers so we don't
// need to send the next Router Solicitation message.
- ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval)
+ ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
}
+ ndp.nic.mu.Unlock()
})
}
@@ -1913,13 +1943,15 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The NIC ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicitTimer == nil {
+ if ndp.rtrSolicit.timer == nil {
// Nothing to do.
return
}
- ndp.rtrSolicitTimer.Stop()
- ndp.rtrSolicitTimer = nil
+ *ndp.rtrSolicit.done = true
+ ndp.rtrSolicit.timer.Stop()
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
}
// initializeTempAddrState initializes state related to temporary SLAAC
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index b3d174cdd..644ba7c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -36,15 +36,24 @@ import (
)
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
- defaultTimeout = 100 * time.Millisecond
- defaultAsyncEventTimeout = time.Second
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 10 * time.Second
+
+ // Extra time to use when waiting for an async event to not occur.
+ //
+ // Since a negative check is used to make sure an event did not happen, it is
+ // okay to use a smaller timeout compared to the positive case since execution
+ // stall in regards to the monotonic clock will not affect the expected
+ // outcome.
+ defaultAsyncNegativeEventTimeout = time.Second
)
var (
@@ -421,45 +430,90 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
+ // We add a default route so the call to FindRoute below will succeed
+ // once we have an assigned address.
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: addr3,
+ NIC: nicID,
+ }})
+
if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Make sure the address does not resolve before the resolution time has
// passed.
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout)
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+ // Should not get a route even if we specify the local address as the
+ // tentative address.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
}
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
+ }
+
+ if t.Failed() {
+ t.FailNow()
}
// Wait for DAD to resolve.
select {
- case <-time.After(2 * defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if addr.Address != addr1 {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
}
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ // Should get a route using the address now that it is resolved.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
+ }
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
+ }
+
+ if t.Failed() {
+ t.FailNow()
}
// Should not have sent any more NS messages.
@@ -613,7 +667,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -935,7 +989,7 @@ func TestSetNDPConfigurations(t *testing.T) {
// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
@@ -970,14 +1024,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+ return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
// raBufWithOpts returns a valid NDP Router Advertisement with options.
//
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
}
@@ -986,7 +1040,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer {
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
}
@@ -994,7 +1048,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo
//
// Note, raBuf does not populate any of the RA fields other than the
// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
+func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
@@ -1003,7 +1057,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
//
// Note, raBufWithPI does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer {
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer {
flags := uint8(0)
if onLink {
// The OnLink flag is the 7th bit in the flags byte.
@@ -1124,7 +1178,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
select {
case <-ndpDisp.routerC:
t.Fatal("should not have received any router events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -1200,14 +1254,14 @@ func TestRouterDiscovery(t *testing.T) {
default:
}
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1217,14 +1271,14 @@ func TestRouterDiscovery(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
expectRouterEvent(llAddr2, false)
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1373,7 +1427,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("should not have received any prefix events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -1448,14 +1502,14 @@ func TestPrefixDiscovery(t *testing.T) {
default:
}
- // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1520,7 +1574,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with finite lifetime.
@@ -1545,7 +1599,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with a prefix with a lifetime value greater than the
@@ -1554,7 +1608,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout):
+ case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with 0 lifetime.
@@ -1790,7 +1844,7 @@ func TestAutoGenAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -1917,7 +1971,7 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -1930,7 +1984,7 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -2036,10 +2090,10 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
@@ -2135,7 +2189,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
@@ -2143,7 +2197,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Errorf("got unxpected auto gen addr event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -2220,7 +2274,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
select {
@@ -2228,7 +2282,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2318,13 +2372,13 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
t.Fatal(mismatch)
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -2341,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
for _, addr := range tempAddrs {
// Wait for a deprecation then invalidation event, or just an invalidation
// event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation timers could fire in any
+ // cases because the deprecation and invalidation jobs could execute in any
// order.
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2353,7 +2407,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
@@ -2362,12 +2416,12 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
} else {
t.Fatalf("got unexpected auto-generated event = %+v", e)
}
- case <-time.After(invalidateAfter + defaultAsyncEventTimeout):
+ case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
@@ -2378,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
}
-// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
-// regeneration timer gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
nicID = 1
regenAfter = 2 * time.Second
@@ -2472,14 +2526,14 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
}
// Prefer the prefix again.
//
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a timer.
+ // - this regeneration does not depend on a job.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(tempAddr2, newAddr)
@@ -2501,24 +2555,24 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
}
// Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration timer gets reset.
+ // RA, the regeneration job gets scheduled again.
//
// The maximum lifetime is the sum of the minimum lifetimes for temporary
// addresses + the time that has already passed since the last address was
- // generated so that the regeneration timer is needed to generate the next
+ // generated so that the regeneration job is needed to generate the next
// address.
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout
+ newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
}
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
@@ -2666,7 +2720,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2679,7 +2733,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -2939,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr(addr2)
}
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
@@ -3025,7 +3079,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3065,7 +3119,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3079,7 +3133,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3111,7 +3165,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
@@ -3120,12 +3174,12 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
} else {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3250,7 +3304,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -3394,7 +3448,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
}
// Wait for the invalidation event.
@@ -3403,7 +3457,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(2 * defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -3459,12 +3513,12 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
expectAutoGenAddrEvent(addr, invalidatedAddr)
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -3627,7 +3681,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -3725,7 +3779,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3792,7 +3846,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -3818,7 +3872,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -3985,7 +4039,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -4104,7 +4158,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -4206,7 +4260,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
@@ -4232,7 +4286,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
}
} else {
@@ -4240,7 +4294,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
}
- case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for auto gen addr event")
}
}
@@ -4824,7 +4878,7 @@ func TestCleanupNDPState(t *testing.T) {
// Should not get any more events (invalidation timers should have been
// cancelled when the NDP state was cleaned up).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout)
select {
case <-ndpDisp.routerC:
t.Error("unexpected router event")
@@ -5127,24 +5181,24 @@ func TestRouterSolicitation(t *testing.T) {
// Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
remaining--
}
for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout)
- waitForPkt(2 * defaultAsyncEventTimeout)
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
+ waitForPkt(defaultAsyncPositiveEventTimeout)
} else {
- waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
}
}
// Make sure no more RS.
if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout)
+ waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
}
// Make sure the counter got properly
@@ -5260,11 +5314,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stop soliciting routers.
test.stopFn(t, s, true /* first */)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
// A single RS may have been sent before solicitations were stopped.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok = e.ReadContext(ctx); ok {
t.Fatal("should not have sent more than one RS message")
@@ -5274,7 +5328,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
test.stopFn(t, s, false /* first */)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
@@ -5287,10 +5341,10 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Start soliciting routers.
test.startFn(t, s)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
+ waitForPkt(delay + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
@@ -5299,7 +5353,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Starting router solicitations after it has already completed should do
// nothing.
test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after finishing router solicitations")
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
new file mode 100644
index 000000000..1d37716c2
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -0,0 +1,335 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const neighborCacheSize = 512 // max entries per interface
+
+// neighborCache maps IP addresses to link addresses. It uses the Least
+// Recently Used (LRU) eviction strategy to implement a bounded cache for
+// dynmically acquired entries. It contains the state machine and configuration
+// for running Neighbor Unreachability Detection (NUD).
+//
+// There are two types of entries in the neighbor cache:
+// 1. Dynamic entries are discovered automatically by neighbor discovery
+// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm
+// reachability with the device once the entry's state becomes Stale.
+// 2. Static entries are explicitly added by a user and have no expiration.
+// Their state is always Static. The amount of static entries stored in the
+// cache is unbounded.
+//
+// neighborCache implements NUDHandler.
+type neighborCache struct {
+ nic *NIC
+ state *NUDState
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ cache map[tcpip.Address]*neighborEntry
+ dynamic struct {
+ lru neighborEntryList
+
+ // count tracks the amount of dynamic entries in the cache. This is
+ // needed since static entries do not count towards the LRU cache
+ // eviction strategy.
+ count uint16
+ }
+}
+
+var _ NUDHandler = (*neighborCache)(nil)
+
+// getOrCreateEntry retrieves a cache entry associated with addr. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if entry, ok := n.cache[remoteAddr]; ok {
+ entry.mu.RLock()
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.lru.PushFront(entry)
+ }
+ entry.mu.RUnlock()
+ return entry
+ }
+
+ // The entry that needs to be created must be dynamic since all static
+ // entries are directly added to the cache via addStaticEntry.
+ entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes)
+ if n.dynamic.count == neighborCacheSize {
+ e := n.dynamic.lru.Back()
+ e.mu.Lock()
+
+ delete(n.cache, e.neigh.Addr)
+ n.dynamic.lru.Remove(e)
+ n.dynamic.count--
+
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Unknown)
+ e.notifyWakersLocked()
+ e.mu.Unlock()
+ }
+ n.cache[remoteAddr] = entry
+ n.dynamic.lru.PushFront(entry)
+ n.dynamic.count++
+ return entry
+}
+
+// entry looks up the neighbor cache for translating address to link address
+// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
+// is a LinkAddressResolver registered with the network protocol, the cache
+// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
+// provided, it will be notified when address resolution is complete (success
+// or not).
+//
+// If address resolution is required, ErrNoLinkAddress and a notification
+// channel is returned for the top level caller to block. Channel is closed
+// once address resolution is complete (success or not).
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ if linkRes != nil {
+ if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
+ e := NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ }
+ return e, nil, nil
+ }
+ }
+
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
+ switch s := entry.neigh.State; s {
+ case Reachable, Static:
+ return entry.neigh, nil, nil
+
+ case Unknown, Incomplete, Stale, Delay, Probe:
+ entry.addWakerLocked(w)
+
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.neigh, nil, tcpip.ErrNoLinkAddress
+ }
+ entry.done = make(chan struct{})
+ }
+
+ entry.handlePacketQueuedLocked()
+ return entry.neigh, entry.done, tcpip.ErrWouldBlock
+
+ case Failed:
+ return entry.neigh, nil, tcpip.ErrNoLinkAddress
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", s))
+ }
+}
+
+// removeWaker removes a waker that has been added when link resolution for
+// addr was requested.
+func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
+ n.mu.Lock()
+ if entry, ok := n.cache[addr]; ok {
+ delete(entry.wakers, waker)
+ }
+ n.mu.Unlock()
+}
+
+// entries returns all entries in the neighbor cache.
+func (n *neighborCache) entries() []NeighborEntry {
+ entries := make([]NeighborEntry, 0, len(n.cache))
+ n.mu.RLock()
+ for _, entry := range n.cache {
+ entry.mu.RLock()
+ entries = append(entries, entry.neigh)
+ entry.mu.RUnlock()
+ }
+ n.mu.RUnlock()
+ return entries
+}
+
+// addStaticEntry adds a static entry to the neighbor cache, mapping an IP
+// address to a link address. If a dynamic entry exists in the neighbor cache
+// with the same address, it will be replaced with this static entry. If a
+// static entry exists with the same address but different link address, it
+// will be updated with the new link address. If a static entry exists with the
+// same address and link address, nothing will happen.
+func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if entry, ok := n.cache[addr]; ok {
+ entry.mu.Lock()
+ if entry.neigh.State != Static {
+ // Dynamic entry found with the same address.
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ } else if entry.neigh.LinkAddr == linkAddr {
+ // Static entry found with the same address and link address.
+ entry.mu.Unlock()
+ return
+ } else {
+ // Static entry found with the same address but different link address.
+ entry.neigh.LinkAddr = linkAddr
+ entry.dispatchChangeEventLocked(entry.neigh.State)
+ entry.mu.Unlock()
+ return
+ }
+
+ // Notify that resolution has been interrupted, just in case the entry was
+ // in the Incomplete or Probe state.
+ entry.dispatchRemoveEventLocked()
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+ entry.mu.Unlock()
+ }
+
+ entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
+ n.cache[addr] = entry
+}
+
+// removeEntryLocked removes the specified entry from the neighbor cache.
+func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+ if entry.neigh.State != Failed {
+ entry.dispatchRemoveEventLocked()
+ }
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+
+ delete(n.cache, entry.neigh.Addr)
+}
+
+// removeEntry removes a dynamic or static entry by address from the neighbor
+// cache. Returns true if the entry was found and deleted.
+func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ entry, ok := n.cache[addr]
+ if !ok {
+ return false
+ }
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
+ n.removeEntryLocked(entry)
+ return true
+}
+
+// clear removes all dynamic and static entries from the neighbor cache.
+func (n *neighborCache) clear() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ for _, entry := range n.cache {
+ entry.mu.Lock()
+ entry.dispatchRemoveEventLocked()
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+ entry.mu.Unlock()
+ }
+
+ n.dynamic.lru = neighborEntryList{}
+ n.cache = make(map[tcpip.Address]*neighborEntry)
+ n.dynamic.count = 0
+}
+
+// config returns the NUD configuration.
+func (n *neighborCache) config() NUDConfigurations {
+ return n.state.Config()
+}
+
+// setConfig changes the NUD configuration.
+//
+// If config contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *neighborCache) setConfig(config NUDConfigurations) {
+ config.resetInvalidFields()
+ n.state.SetConfig(config)
+}
+
+// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
+// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
+// by the caller.
+func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) {
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, nil)
+ entry.mu.Lock()
+ entry.handleProbeLocked(remoteLinkAddr)
+ entry.mu.Unlock()
+}
+
+// HandleConfirmation implements NUDHandler.HandleConfirmation by following the
+// logic defined in RFC 4861 section 7.2.5.
+//
+// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
+// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
+// should be deployed where preventing access to the broadcast segment might
+// not be possible. SEND uses RSA key pairs to produce cryptographically
+// generated addresses, as defined in RFC 3972, Cryptographically Generated
+// Addresses (CGA). This ensures that the claimed source of an NDP message is
+// the owner of the claimed address.
+func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+ n.mu.RLock()
+ entry, ok := n.cache[addr]
+ n.mu.RUnlock()
+ if ok {
+ entry.mu.Lock()
+ entry.handleConfirmationLocked(linkAddr, flags)
+ entry.mu.Unlock()
+ }
+ // The confirmation SHOULD be silently discarded if the recipient did not
+ // initiate any communication with the target. This is indicated if there is
+ // no matching entry for the remote address.
+}
+
+// HandleUpperLevelConfirmation implements
+// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in
+// RFC 4861 section 7.3.1.
+func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) {
+ n.mu.RLock()
+ entry, ok := n.cache[addr]
+ n.mu.RUnlock()
+ if ok {
+ entry.mu.Lock()
+ entry.handleUpperLevelConfirmationLocked()
+ entry.mu.Unlock()
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
new file mode 100644
index 000000000..4cb2c9c6b
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -0,0 +1,1752 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // entryStoreSize is the default number of entries that will be generated and
+ // added to the entry store. This number needs to be larger than the size of
+ // the neighbor cache to give ample opportunity for verifying behavior during
+ // cache overflows. Four times the size of the neighbor cache allows for
+ // three complete cache overflows.
+ entryStoreSize = 4 * neighborCacheSize
+
+ // typicalLatency is the typical latency for an ARP or NDP packet to travel
+ // to a router and back.
+ typicalLatency = time.Millisecond
+
+ // testEntryBroadcastAddr is a special address that indicates a packet should
+ // be sent to all nodes.
+ testEntryBroadcastAddr = tcpip.Address("broadcast")
+
+ // testEntryLocalAddr is the source address of neighbor probes.
+ testEntryLocalAddr = tcpip.Address("local_addr")
+
+ // testEntryBroadcastLinkAddr is a special link address sent back to
+ // multicast neighbor probes.
+ testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast")
+
+ // infiniteDuration indicates that a task will not occur in our lifetime.
+ infiniteDuration = time.Duration(math.MaxInt64)
+)
+
+// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor
+// entries. The UpdatedAt field is ignored due to a lack of a deterministic
+// method to predict the time that an event will be dispatched.
+func entryDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ }
+}
+
+// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to
+// sort slices of entries for cases where ordering must be ignored.
+func entryDiffOptsWithSort() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ cmpopts.SortSlices(func(a, b NeighborEntry) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
+}
+
+func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
+ config.resetInvalidFields()
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ return &neighborCache{
+ nic: &NIC{
+ stack: &Stack{
+ clock: clock,
+ nudDisp: nudDisp,
+ },
+ id: 1,
+ },
+ state: NewNUDState(config, rng),
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+}
+
+// testEntryStore contains a set of IP to NeighborEntry mappings.
+type testEntryStore struct {
+ mu sync.RWMutex
+ entriesMap map[tcpip.Address]NeighborEntry
+}
+
+func toAddress(i int) tcpip.Address {
+ buf := new(bytes.Buffer)
+ binary.Write(buf, binary.BigEndian, uint8(1))
+ binary.Write(buf, binary.BigEndian, uint8(0))
+ binary.Write(buf, binary.BigEndian, uint16(i))
+ return tcpip.Address(buf.String())
+}
+
+func toLinkAddress(i int) tcpip.LinkAddress {
+ buf := new(bytes.Buffer)
+ binary.Write(buf, binary.BigEndian, uint8(1))
+ binary.Write(buf, binary.BigEndian, uint8(0))
+ binary.Write(buf, binary.BigEndian, uint32(i))
+ return tcpip.LinkAddress(buf.String())
+}
+
+// newTestEntryStore returns a testEntryStore pre-populated with entries.
+func newTestEntryStore() *testEntryStore {
+ store := &testEntryStore{
+ entriesMap: make(map[tcpip.Address]NeighborEntry),
+ }
+ for i := 0; i < entryStoreSize; i++ {
+ addr := toAddress(i)
+ linkAddr := toLinkAddress(i)
+
+ store.entriesMap[addr] = NeighborEntry{
+ Addr: addr,
+ LocalAddr: testEntryLocalAddr,
+ LinkAddr: linkAddr,
+ }
+ }
+ return store
+}
+
+// size returns the number of entries in the store.
+func (s *testEntryStore) size() int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return len(s.entriesMap)
+}
+
+// entry returns the entry at index i. Returns an empty entry and false if i is
+// out of bounds.
+func (s *testEntryStore) entry(i int) (NeighborEntry, bool) {
+ return s.entryByAddr(toAddress(i))
+}
+
+// entryByAddr returns the entry matching addr for situations when the index is
+// not available. Returns an empty entry and false if no entries match addr.
+func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ entry, ok := s.entriesMap[addr]
+ return entry, ok
+}
+
+// entries returns all entries in the store.
+func (s *testEntryStore) entries() []NeighborEntry {
+ entries := make([]NeighborEntry, 0, len(s.entriesMap))
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for i := 0; i < entryStoreSize; i++ {
+ addr := toAddress(i)
+ if entry, ok := s.entriesMap[addr]; ok {
+ entries = append(entries, entry)
+ }
+ }
+ return entries
+}
+
+// set modifies the link addresses of an entry.
+func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) {
+ addr := toAddress(i)
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if entry, ok := s.entriesMap[addr]; ok {
+ entry.LinkAddr = linkAddr
+ s.entriesMap[addr] = entry
+ }
+}
+
+// testNeighborResolver implements LinkAddressResolver to emulate sending a
+// neighbor probe.
+type testNeighborResolver struct {
+ clock tcpip.Clock
+ neigh *neighborCache
+ entries *testEntryStore
+ delay time.Duration
+ onLinkAddressRequest func()
+}
+
+var _ LinkAddressResolver = (*testNeighborResolver)(nil)
+
+func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ // Delay handling the request to emulate network latency.
+ r.clock.AfterFunc(r.delay, func() {
+ r.fakeRequest(addr)
+ })
+
+ // Execute post address resolution action, if available.
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
+ return nil
+}
+
+// fakeRequest emulates handling a response for a link address request.
+func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) {
+ if entry, ok := r.entries.entryByAddr(addr); ok {
+ r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ }
+}
+
+func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == testEntryBroadcastAddr {
+ return testEntryBroadcastLinkAddr, true
+ }
+ return "", false
+}
+
+func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return 0
+}
+
+type entryEvent struct {
+ nicID tcpip.NICID
+ address tcpip.Address
+ linkAddr tcpip.LinkAddress
+ state NeighborState
+}
+
+func TestNeighborCacheGetConfig(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+
+ if got, want := neigh.config(), c; got != want {
+ t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheSetConfig(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+ neigh.setConfig(c)
+
+ if got, want := neigh.config(), c; got != want {
+ t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheEntry(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+
+ // No more events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a
+// LinkAddressResolver returns ErrNoLinkAddress.
+func TestNeighborCacheEntryNoLinkAddress(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+ store := newTestEntryStore()
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil)
+ if err != tcpip.ErrNoLinkAddress {
+ t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheRemoveEntry(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ neigh.removeEntry(entry.Addr)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+}
+
+type testContext struct {
+ clock *fakeClock
+ neigh *neighborCache
+ store *testEntryStore
+ linkRes *testNeighborResolver
+ nudDisp *testNUDDispatcher
+}
+
+func newTestContext(c NUDConfigurations) testContext {
+ nudDisp := &testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nudDisp, c, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ return testContext{
+ clock: clock,
+ neigh: neigh,
+ store: store,
+ linkRes: linkRes,
+ nudDisp: nudDisp,
+ }
+}
+
+type overflowOptions struct {
+ startAtEntryIndex int
+ wantStaticEntries []NeighborEntry
+}
+
+func (c *testContext) overflowCache(opts overflowOptions) error {
+ // Fill the neighbor cache to capacity to verify the LRU eviction strategy is
+ // working properly after the entry removal.
+ for i := opts.startAtEntryIndex; i < c.store.size(); i++ {
+ // Add a new entry
+ entry, ok := c.store.entry(i)
+ if !ok {
+ return fmt.Errorf("c.store.entry(%d) not found", i)
+ }
+ if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(c.neigh.config().RetransmitTimer)
+
+ var wantEvents []testEntryEventInfo
+
+ // When beyond the full capacity, the cache will evict an entry as per the
+ // LRU eviction strategy. Note that the number of static entries should not
+ // affect the total number of dynamic entries that can be added.
+ if i >= neighborCacheSize+opts.startAtEntryIndex {
+ removedEntry, ok := c.store.entry(i - neighborCacheSize)
+ if !ok {
+ return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize)
+ }
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ })
+ }
+
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ }, testEntryEventInfo{
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ })
+
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Expect to find only the most recent entries. The order of entries reported
+ // by entries() is undeterministic, so entries have to be sorted before
+ // comparison.
+ wantUnsortedEntries := opts.wantStaticEntries
+ for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
+ entry, ok := c.store.entry(i)
+ if !ok {
+ return fmt.Errorf("c.store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No more events should have been dispatched.
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ return nil
+}
+
+// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy
+// respects the dynamic entry count.
+func TestNeighborCacheOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache
+// eviction strategy respects the dynamic entry count when an entry is removed.
+func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(c.neigh.config().RetransmitTimer)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the entry
+ c.neigh.removeEntry(entry.Addr)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that
+// adding a duplicate static entry with the same link address does not dispatch
+// any events.
+func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the static entry that was just added
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+
+ // No more events should have been dispatched.
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that
+// adding a duplicate static entry with a different link address dispatches a
+// change event.
+func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add a duplicate entry with a different link address
+ staticLinkAddr += "duplicate"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+}
+
+// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache
+// eviction strategy respects the dynamic entry count when a static entry is
+// added then removed. In this case, the dynamic entry count shouldn't have
+// been touched.
+func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the static entry that was just added
+ c.neigh.removeEntry(entry.Addr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU
+// cache eviction strategy keeps count of the dynamic entry count when an entry
+// is overwritten by a static entry. Static entries should not count towards
+// the size of the LRU cache.
+func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(typicalLatency)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Override the entry with a static one using the same address
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 1,
+ wantStaticEntries: []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ },
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheNotifiesWaker(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ const wakerID = 1
+ s.AddWaker(&w, wakerID)
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh == nil {
+ t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+ clock.advance(typicalLatency)
+
+ select {
+ case <-doneCh:
+ default:
+ t.Fatal("expected notification from done channel")
+ }
+
+ id, ok := s.Fetch(false /* block */)
+ if !ok {
+ t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+ if id != wakerID {
+ t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheRemoveWaker(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ const wakerID = 1
+ s.AddWaker(&w, wakerID)
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh == nil {
+ t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+
+ // Remove the waker before the neighbor cache has the opportunity to send a
+ // notification.
+ neigh.removeWaker(entry.Addr, &w)
+ clock.advance(typicalLatency)
+
+ select {
+ case <-doneCh:
+ default:
+ t.Fatal("expected notification from done channel")
+ }
+
+ if id, ok := s.Fetch(false /* block */); ok {
+ t.Errorf("unexpected notification from waker with id %d", id)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
+ e, _, err := c.neigh.entry(entry.Addr, "", nil, nil)
+ if err != nil {
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err)
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 1,
+ wantStaticEntries: []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
+ },
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheClear(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ // Add a dynamic entry.
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add a static entry.
+ neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Clear shoud remove both dynamic and static entries.
+ neigh.clear()
+
+ // Remove events dispatched from clear() have no deterministic order so they
+ // need to be sorted beforehand.
+ wantUnsortedEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction
+// strategy keeps count of the dynamic entry count when all entries are
+// cleared.
+func TestNeighborCacheClearThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(typicalLatency)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Clear the cache.
+ c.neigh.clear()
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ frequentlyUsedEntry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+
+ // The following logic is very similar to overflowCache, but
+ // periodically refreshes the frequently used entry.
+
+ // Fill the neighbor cache to capacity
+ for i := 0; i < neighborCacheSize; i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Keep adding more entries
+ for i := neighborCacheSize; i < store.size(); i++ {
+ // Periodically refresh the frequently used entry
+ if i%(neighborCacheSize/2) == 0 {
+ _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err)
+ }
+ }
+
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+
+ // An entry should have been removed, as per the LRU eviction strategy
+ removedEntry, ok := store.entry(i - neighborCacheSize + 1)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1)
+ }
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Expect to find only the frequently used entry and the most recent entries.
+ // The order of entries reported by entries() is undeterministic, so entries
+ // have to be sorted before comparison.
+ wantUnsortedEntries := []NeighborEntry{
+ {
+ Addr: frequentlyUsedEntry.Addr,
+ LocalAddr: frequentlyUsedEntry.LocalAddr,
+ LinkAddr: frequentlyUsedEntry.LinkAddr,
+ State: Reachable,
+ },
+ }
+
+ for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No more events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheConcurrent(t *testing.T) {
+ const concurrentProcesses = 16
+
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ storeEntries := store.entries()
+ for _, entry := range storeEntries {
+ var wg sync.WaitGroup
+ for r := 0; r < concurrentProcesses; r++ {
+ wg.Add(1)
+ go func(entry NeighborEntry) {
+ defer wg.Done()
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil && err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock)
+ }
+ }(entry)
+ }
+
+ // Wait for all gorountines to send a request
+ wg.Wait()
+
+ // Process all the requests for a single entry concurrently
+ clock.advance(typicalLatency)
+ }
+
+ // All goroutines add in the same order and add more values than can fit in
+ // the cache. Our eviction strategy requires that the last entries are
+ // present, up to the size of the neighbor cache, and the rest are missing.
+ // The order of entries reported by entries() is undeterministic, so entries
+ // have to be sorted before comparison.
+ var wantUnsortedEntries []NeighborEntry
+ for i := store.size() - neighborCacheSize; i < store.size(); i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Errorf("store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheReplace(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ // Add an entry
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+
+ // Verify the entry exists
+ e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ if doneCh != nil {
+ t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ }
+
+ // Notify of a link address change
+ var updatedLinkAddr tcpip.LinkAddress
+ {
+ entry, ok := store.entry(1)
+ if !ok {
+ t.Fatalf("store.entry(1) not found")
+ }
+ updatedLinkAddr = entry.LinkAddr
+ }
+ store.set(0, updatedLinkAddr)
+ neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+
+ // Requesting the entry again should start address resolution
+ {
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(config.DelayFirstProbeTime + typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+ }
+
+ // Verify the entry's new link address
+ {
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ clock.advance(typicalLatency)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ want = NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+ }
+}
+
+func TestNeighborCacheResolutionFailed(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+
+ var requestCount uint32
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ onLinkAddressRequest: func() {
+ atomic.AddUint32(&requestCount, 1)
+ },
+ }
+
+ // First, sanity check that resolution is working
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+
+ // Verify that address resolution for an unknown address returns ErrNoLinkAddress
+ before := atomic.LoadUint32(&requestCount)
+
+ entry.Addr += "2"
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
+ clock.advance(waitFor)
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+
+ maxAttempts := neigh.config().MaxUnicastProbes
+ if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
+}
+
+// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes
+// probes and not retrieving a confirmation before the duration defined by
+// MaxMulticastProbes * RetransmitTimer.
+func TestNeighborCacheResolutionTimeout(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ config.RetransmitTimer = time.Millisecond // small enough to cause timeout
+
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: time.Minute, // large enough to cause timeout
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
+ clock.advance(waitFor)
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+}
+
+// TestNeighborCacheStaticResolution checks that static link addresses are
+// resolved immediately and don't send resolution requests.
+func TestNeighborCacheStaticResolution(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err)
+ }
+ want := NeighborEntry{
+ Addr: testEntryBroadcastAddr,
+ LocalAddr: testEntryLocalAddr,
+ LinkAddr: testEntryBroadcastLinkAddr,
+ State: Static,
+ }
+ if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff)
+ }
+}
+
+func BenchmarkCacheClear(b *testing.B) {
+ b.StopTimer()
+ config := DefaultNUDConfigurations()
+ clock := &tcpip.StdClock{}
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: 0,
+ }
+
+ // Clear for every possible size of the cache
+ for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ {
+ // Fill the neighbor cache to capacity.
+ for i := 0; i < cacheSize; i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ b.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh != nil {
+ <-doneCh
+ }
+ }
+
+ b.StartTimer()
+ neigh.clear()
+ b.StopTimer()
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
new file mode 100644
index 000000000..0068cacb8
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -0,0 +1,482 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// NeighborEntry describes a neighboring device in the local network.
+type NeighborEntry struct {
+ Addr tcpip.Address
+ LocalAddr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
+}
+
+// NeighborState defines the state of a NeighborEntry within the Neighbor
+// Unreachability Detection state machine, as per RFC 4861 section 7.3.2.
+type NeighborState uint8
+
+const (
+ // Unknown means reachability has not been verified yet. This is the initial
+ // state of entries that have been created automatically by the Neighbor
+ // Unreachability Detection state machine.
+ Unknown NeighborState = iota
+ // Incomplete means that there is an outstanding request to resolve the
+ // address.
+ Incomplete
+ // Reachable means the path to the neighbor is functioning properly for both
+ // receive and transmit paths.
+ Reachable
+ // Stale means reachability to the neighbor is unknown, but packets are still
+ // able to be transmitted to the possibly stale link address.
+ Stale
+ // Delay means reachability to the neighbor is unknown and pending
+ // confirmation from an upper-level protocol like TCP, but packets are still
+ // able to be transmitted to the possibly stale link address.
+ Delay
+ // Probe means a reachability confirmation is actively being sought by
+ // periodically retransmitting reachability probes until a reachability
+ // confirmation is received, or until the max amount of probes has been sent.
+ Probe
+ // Static describes entries that have been explicitly added by the user. They
+ // do not expire and are not deleted until explicitly removed.
+ Static
+ // Failed means traffic should not be sent to this neighbor since attempts of
+ // reachability have returned inconclusive.
+ Failed
+)
+
+// neighborEntry implements a neighbor entry's individual node behavior, as per
+// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in
+// parallel with the sending of packets to a neighbor, necessitating the
+// entry's lock to be acquired for all operations.
+type neighborEntry struct {
+ neighborEntryEntry
+
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkRes provides the functionality to send reachability probes, used in
+ // Neighbor Unreachability Detection.
+ linkRes LinkAddressResolver
+
+ // nudState points to the Neighbor Unreachability Detection configuration.
+ nudState *NUDState
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ neigh NeighborEntry
+
+ // wakers is a set of waiters for address resolution result. Anytime state
+ // transitions out of incomplete these waiters are notified. It is nil iff
+ // address resolution is ongoing and no clients are waiting for the result.
+ wakers map[*sleep.Waker]struct{}
+
+ // done is used to allow callers to wait on address resolution. It is nil
+ // iff nudState is not Reachable and address resolution is not yet in
+ // progress.
+ done chan struct{}
+
+ isRouter bool
+ job *tcpip.Job
+}
+
+// newNeighborEntry creates a neighbor cache entry starting at the default
+// state, Unknown. Transition out of Unknown by calling either
+// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
+// neighborEntry.
+func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+ return &neighborEntry{
+ nic: nic,
+ linkRes: linkRes,
+ nudState: nudState,
+ neigh: NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ State: Unknown,
+ },
+ }
+}
+
+// newStaticNeighborEntry creates a neighbor cache entry starting at the Static
+// state. The entry can only transition out of Static by directly calling
+// `setStateLocked`.
+func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+ if nic.stack.nudDisp != nil {
+ nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now())
+ }
+ return &neighborEntry{
+ nic: nic,
+ nudState: state,
+ neigh: NeighborEntry{
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ },
+ }
+}
+
+// addWaker adds w to the list of wakers waiting for address resolution.
+// Assumes the entry has already been appropriately locked.
+func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
+ if w == nil {
+ return
+ }
+ if e.wakers == nil {
+ e.wakers = make(map[*sleep.Waker]struct{})
+ }
+ e.wakers[w] = struct{}{}
+}
+
+// notifyWakersLocked notifies those waiting for address resolution, whether it
+// succeeded or failed. Assumes the entry has already been appropriately locked.
+func (e *neighborEntry) notifyWakersLocked() {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
+// been added.
+func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ }
+}
+
+// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
+// has changed state or link-layer address.
+func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ }
+}
+
+// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
+// has been removed.
+func (e *neighborEntry) dispatchRemoveEventLocked() {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now())
+ }
+}
+
+// setStateLocked transitions the entry to the specified state immediately.
+//
+// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// e.mu MUST be locked.
+func (e *neighborEntry) setStateLocked(next NeighborState) {
+ // Cancel the previously scheduled action, if there is one. Entries in
+ // Unknown, Stale, or Static state do not have scheduled actions.
+ if timer := e.job; timer != nil {
+ timer.Cancel()
+ }
+
+ prev := e.neigh.State
+ e.neigh.State = next
+ e.neigh.UpdatedAt = time.Now()
+ config := e.nudState.Config()
+
+ switch next {
+ case Incomplete:
+ var retryCounter uint32
+ var sendMulticastProbe func()
+
+ sendMulticastProbe = func() {
+ if retryCounter == config.MaxMulticastProbes {
+ // "If no Neighbor Advertisement is received after
+ // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
+ // The sender MUST return ICMP destination unreachable indications with
+ // code 3 (Address Unreachable) for each packet queued awaiting address
+ // resolution." - RFC 4861 section 7.2.2
+ //
+ // There is no need to send an ICMP destination unreachable indication
+ // since the failure to resolve the address is expected to only occur
+ // on this node. Thus, redirecting traffic is currently not supported.
+ //
+ // "If the error occurs on a node other than the node originating the
+ // packet, an ICMP error message is generated. If the error occurs on
+ // the originating node, an implementation is not required to actually
+ // create and send an ICMP error packet to the source, as long as the
+ // upper-layer sender is notified through an appropriate mechanism
+ // (e.g. return value from a procedure call). Note, however, that an
+ // implementation may find it convenient in some cases to return errors
+ // to the sender by taking the offending packet, generating an ICMP
+ // error message, and then delivering it (locally) through the generic
+ // error-handling routines.' - RFC 4861 section 2.1
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ // There is no need to log the error here; the NUD implementation may
+ // assume a working link. A valid link should be the responsibility of
+ // the NIC/stack.LinkEndpoint.
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendMulticastProbe()
+
+ case Reachable:
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ })
+ e.job.Schedule(e.nudState.ReachableTime())
+
+ case Delay:
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.dispatchChangeEventLocked(Probe)
+ e.setStateLocked(Probe)
+ })
+ e.job.Schedule(config.DelayFirstProbeTime)
+
+ case Probe:
+ var retryCounter uint32
+ var sendUnicastProbe func()
+
+ sendUnicastProbe = func() {
+ if retryCounter == config.MaxUnicastProbes {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ if retryCounter == config.MaxUnicastProbes {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendUnicastProbe()
+
+ case Failed:
+ e.notifyWakersLocked()
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.nic.neigh.removeEntryLocked(e)
+ })
+ e.job.Schedule(config.UnreachableTime)
+
+ case Unknown, Stale, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next))
+ }
+}
+
+// handlePacketQueuedLocked advances the state machine according to a packet
+// being queued for outgoing transmission.
+//
+// Follows the logic defined in RFC 4861 section 7.3.3.
+func (e *neighborEntry) handlePacketQueuedLocked() {
+ switch e.neigh.State {
+ case Unknown:
+ e.dispatchAddEventLocked(Incomplete)
+ e.setStateLocked(Incomplete)
+
+ case Stale:
+ e.dispatchChangeEventLocked(Delay)
+ e.setStateLocked(Delay)
+
+ case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or
+// Neighbor Solicitation for ARP or NDP, respectively).
+//
+// Follows the logic defined in RFC 4861 section 7.2.3.
+func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
+ // Probes MUST be silently discarded if the target address is tentative, does
+ // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
+ // checks MUST be done by the NetworkEndpoint.
+
+ switch e.neigh.State {
+ case Unknown, Incomplete, Failed:
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchAddEventLocked(Stale)
+ e.setStateLocked(Stale)
+ e.notifyWakersLocked()
+
+ case Reachable, Delay, Probe:
+ if e.neigh.LinkAddr != remoteLinkAddr {
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+
+ case Stale:
+ if e.neigh.LinkAddr != remoteLinkAddr {
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchChangeEventLocked(Stale)
+ }
+
+ case Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleConfirmationLocked processes an incoming neighbor confirmation
+// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively).
+//
+// Follows the state machine defined by RFC 4861 section 7.2.5.
+//
+// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
+// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
+// should be deployed where preventing access to the broadcast segment might
+// not be possible. SEND uses RSA key pairs to produce Cryptographically
+// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
+// claimed source of an NDP message is the owner of the claimed address.
+func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+ switch e.neigh.State {
+ case Incomplete:
+ if len(linkAddr) == 0 {
+ // "If the link layer has addresses and no Target Link-Layer Address
+ // option is included, the receiving node SHOULD silently discard the
+ // received advertisement." - RFC 4861 section 7.2.5
+ break
+ }
+
+ e.neigh.LinkAddr = linkAddr
+ if flags.Solicited {
+ e.dispatchChangeEventLocked(Reachable)
+ e.setStateLocked(Reachable)
+ } else {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+ e.isRouter = flags.IsRouter
+ e.notifyWakersLocked()
+
+ // "Note that the Override flag is ignored if the entry is in the
+ // INCOMPLETE state." - RFC 4861 section 7.2.5
+
+ case Reachable, Stale, Delay, Probe:
+ sameLinkAddr := e.neigh.LinkAddr == linkAddr
+
+ if !sameLinkAddr {
+ if !flags.Override {
+ if e.neigh.State == Reachable {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+ break
+ }
+
+ e.neigh.LinkAddr = linkAddr
+
+ if !flags.Solicited {
+ if e.neigh.State != Stale {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ } else {
+ // Notify the LinkAddr change, even though NUD state hasn't changed.
+ e.dispatchChangeEventLocked(e.neigh.State)
+ }
+ break
+ }
+ }
+
+ if flags.Solicited && (flags.Override || sameLinkAddr) {
+ if e.neigh.State != Reachable {
+ e.dispatchChangeEventLocked(Reachable)
+ }
+ // Set state to Reachable again to refresh timers.
+ e.setStateLocked(Reachable)
+ e.notifyWakersLocked()
+ }
+
+ if e.isRouter && !flags.IsRouter {
+ // "In those cases where the IsRouter flag changes from TRUE to FALSE as
+ // a result of this update, the node MUST remove that router from the
+ // Default Router List and update the Destination Cache entries for all
+ // destinations using that neighbor as a router as specified in Section
+ // 7.3.3. This is needed to detect when a node that is used as a router
+ // stops forwarding packets due to being configured as a host."
+ // - RFC 4861 section 7.2.5
+ e.nic.mu.Lock()
+ e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr)
+ e.nic.mu.Unlock()
+ }
+ e.isRouter = flags.IsRouter
+
+ case Unknown, Failed, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol
+// (e.g. TCP acknowledgements) reachability confirmation.
+func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
+ switch e.neigh.State {
+ case Reachable, Stale, Delay, Probe:
+ if e.neigh.State != Reachable {
+ e.dispatchChangeEventLocked(Reachable)
+ // Set state to Reachable again to refresh timers.
+ }
+ e.setStateLocked(Reachable)
+
+ case Unknown, Incomplete, Failed, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
new file mode 100644
index 000000000..08c9ccd25
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -0,0 +1,2770 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+
+ entryTestNICID tcpip.NICID = 1
+ entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
+ entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
+
+ // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
+ // except where another value is explicitly used. It is chosen to match the
+ // MTU of loopback interfaces on Linux systems.
+ entryTestNetDefaultMTU = 65536
+)
+
+// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
+// The UpdatedAt field is ignored due to a lack of a deterministic method to
+// predict the time that an event will be dispatched.
+func eventDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ }
+}
+
+// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
+// sort slices of events for cases where ordering must be ignored.
+func eventDiffOptsWithSort() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
+}
+
+// The following unit tests exercise every state transition and verify its
+// behavior with RFC 4681.
+//
+// | From | To | Cause | Action | Event |
+// | ========== | ========== | ========================================== | =============== | ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added |
+// | Unknown | Stale | Probe w/ unknown address | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed |
+// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | |
+// | Reachable | Stale | Reachable timer expired | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
+// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
+// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
+// | Stale | Delay | Packet sent | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | Changed |
+// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
+// | Delay | Probe | Delay timer expired | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
+// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
+// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Failed | | Unreachability timer expired | Delete entry | |
+
+type testEntryEventType uint8
+
+const (
+ entryTestAdded testEntryEventType = iota
+ entryTestChanged
+ entryTestRemoved
+)
+
+func (t testEntryEventType) String() string {
+ switch t {
+ case entryTestAdded:
+ return "add"
+ case entryTestChanged:
+ return "change"
+ case entryTestRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+// Fields are exported for use with cmp.Diff.
+type testEntryEventInfo struct {
+ EventType testEntryEventType
+ NICID tcpip.NICID
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
+}
+
+func (e testEntryEventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State)
+}
+
+// testNUDDispatcher implements NUDDispatcher to validate the dispatching of
+// events upon certain NUD state machine events.
+type testNUDDispatcher struct {
+ mu sync.Mutex
+ events []testEntryEventInfo
+}
+
+var _ NUDDispatcher = (*testNUDDispatcher)(nil)
+
+func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.events = append(d.events, e)
+}
+
+func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestAdded,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestChanged,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+type entryTestLinkResolver struct {
+ mu sync.Mutex
+ probes []entryTestProbeInfo
+}
+
+var _ LinkAddressResolver = (*entryTestLinkResolver)(nil)
+
+type entryTestProbeInfo struct {
+ RemoteAddress tcpip.Address
+ RemoteLinkAddress tcpip.LinkAddress
+ LocalAddress tcpip.Address
+}
+
+func (p entryTestProbeInfo) String() string {
+ return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress)
+}
+
+// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+// to the local network if linkAddr is the zero value.
+func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ p := entryTestProbeInfo{
+ RemoteAddress: addr,
+ RemoteLinkAddress: linkAddr,
+ LocalAddress: localAddr,
+ }
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.probes = append(r.probes, p)
+ return nil
+}
+
+// ResolveStaticAddress attempts to resolve address without sending requests.
+// It either resolves the name immediately or returns the empty LinkAddress.
+func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ return "", false
+}
+
+// LinkAddressProtocol returns the network protocol of the addresses this
+// resolver can resolve.
+func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return entryTestNetNumber
+}
+
+func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) {
+ clock := newFakeClock()
+ disp := testNUDDispatcher{}
+ nic := NIC{
+ id: entryTestNICID,
+ linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ stack: &Stack{
+ clock: clock,
+ nudDisp: &disp,
+ },
+ }
+
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ nudState := NewNUDState(c, rng)
+ linkRes := entryTestLinkResolver{}
+ entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes)
+
+ // Stub out ndpState to verify modification of default routers.
+ nic.mu.ndp = ndpState{
+ nic: &nic,
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ }
+
+ // Stub out the neighbor cache to verify deletion from the cache.
+ nic.neigh = &neighborCache{
+ nic: &nic,
+ state: nudState,
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+ nic.neigh.cache[entryTestAddr1] = entry
+
+ return entry, &disp, &linkRes, clock
+}
+
+// TestEntryInitiallyUnknown verifies that the state of a newly created
+// neighborEntry is Unknown.
+func TestEntryInitiallyUnknown(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Unknown; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Unknown; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(time.Hour)
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnknownToIncomplete(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ {
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+}
+
+func TestEntryUnknownToStale(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ updatedAt := e.neigh.UpdatedAt
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // UpdatedAt should remain the same during address resolution.
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.UpdatedAt, updatedAt; got != want {
+ t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // UpdatedAt should change after failing address resolution. Timing out after
+ // sending the last probe transitions the entry to Failed.
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ clock.advance(c.RetransmitTimer)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant {
+ t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryIncompleteToReachable(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+// TestEntryAddsAndClearsWakers verifies that wakers are added when
+// addWakerLocked is called and cleared when address resolution finishes. In
+// this case, address resolution will finish when transitioning from Incomplete
+// to Reachable.
+func TestEntryAddsAndClearsWakers(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ s.AddWaker(&w, 123)
+ defer s.Done()
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got := e.wakers; got != nil {
+ t.Errorf("got e.wakers = %v, want = nil", got)
+ }
+ e.addWakerLocked(&w)
+ if got, want := w.IsAsserted(), false; got != want {
+ t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ }
+ if e.wakers == nil {
+ t.Error("expected e.wakers to be non-nil")
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.wakers != nil {
+ t.Errorf("got e.wakers = %v, want = nil", e.wakers)
+ }
+ if got, want := w.IsAsserted(), true; got != want {
+ t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.isRouter, true; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ linkRes.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToStale(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Failed; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+type testLocker struct{}
+
+var _ sync.Locker = (*testLocker)(nil)
+
+func (*testLocker) Lock() {}
+func (*testLocker) Unlock() {}
+
+func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.isRouter, true; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{
+ invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}),
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.isRouter, false; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok {
+ t.Errorf("unexpected defaultRouter for %s", entryTestAddr1)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToDelay(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 1
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToProbe(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The next three probe are caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Failed; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryFailedGetsDeleted(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // Verify the cache contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
+ t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ }
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The next three probe are caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ // Verify the cache no longer contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok {
+ t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1)
+ }
+}
diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go
new file mode 100644
index 000000000..aa7311ec6
--- /dev/null
+++ b/pkg/tcpip/stack/neighborstate_string.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type NeighborState"; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Unknown-0]
+ _ = x[Incomplete-1]
+ _ = x[Reachable-2]
+ _ = x[Stale-3]
+ _ = x[Delay-4]
+ _ = x[Probe-5]
+ _ = x[Static-6]
+ _ = x[Failed-7]
+}
+
+const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed"
+
+var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53}
+
+func (i NeighborState) String() string {
+ if i >= NeighborState(len(_NeighborState_index)-1) {
+ return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 54103fdb3..f21066fce 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+ "math/rand"
"reflect"
"sort"
"strings"
@@ -45,6 +46,7 @@ type NIC struct {
context NICContext
stats NICStats
+ neigh *neighborCache
mu struct {
sync.RWMutex
@@ -141,6 +143,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
+ // Check for Neighbor Unreachability Detection support.
+ if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 {
+ rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds()))
+ nic.neigh = &neighborCache{
+ nic: nic,
+ state: NewNUDState(stack.nudConfigs, rng),
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+ }
+
nic.linkEP.Attach(nic)
return nic
@@ -181,7 +193,7 @@ func (n *NIC) disableLocked() *tcpip.Error {
return nil
}
- // TODO(b/147015577): Should Routes that are currently bound to n be
+ // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be
// invalidated? Currently, Routes will continue to work when a NIC is enabled
// again, and applications may not know that the underlying NIC was ever
// disabled.
@@ -457,8 +469,20 @@ type ipv6AddrCandidate struct {
// remoteAddr must be a valid IPv6 address.
func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
n.mu.RLock()
- defer n.mu.RUnlock()
+ ref := n.primaryIPv6EndpointRLocked(remoteAddr)
+ n.mu.RUnlock()
+ return ref
+}
+// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 and 7 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+//
+// n.mu MUST be read locked.
+func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
if len(primaryAddrs) == 0 {
@@ -568,11 +592,6 @@ const (
// promiscuous indicates that the NIC's promiscuous flag should be observed
// when getting a NIC's referenced network endpoint.
promiscuous
-
- // forceSpoofing indicates that the NIC should be assumed to be spoofing,
- // regardless of what the NIC's spoofing flag is when getting a NIC's
- // referenced network endpoint.
- forceSpoofing
)
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
@@ -591,8 +610,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
// Similarly, spoofing will only be checked if spoofing is true.
func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
- id := NetworkEndpointID{address}
-
n.mu.RLock()
var spoofingOrPromiscuous bool
@@ -601,24 +618,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
spoofingOrPromiscuous = n.mu.spoofing
case promiscuous:
spoofingOrPromiscuous = n.mu.promiscuous
- case forceSpoofing:
- spoofingOrPromiscuous = true
}
- if ref, ok := n.mu.endpoints[id]; ok {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// An endpoint with this id exists, check if it can be used and return it.
- switch ref.getKind() {
- case permanentExpired:
- if !spoofingOrPromiscuous {
- n.mu.RUnlock()
- return nil
- }
- fallthrough
- case temporary, permanent:
- if ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
+ if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
+ n.mu.RUnlock()
+ return nil
+ }
+
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
}
}
@@ -654,11 +665,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.mu.endpoints[id]; ok {
+ ref := n.getRefOrCreateTempLocked(protocol, address, peb)
+ n.mu.Unlock()
+ return ref
+}
+
+/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
+/// and returns a temporary endpoint.
+func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// No need to check the type as we are ok with expired endpoints at this
// point.
if ref.tryIncRef() {
- n.mu.Unlock()
return ref
}
// tryIncRef failing means the endpoint is scheduled to be removed once the
@@ -670,7 +688,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
- n.mu.Unlock()
return nil
}
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
@@ -680,8 +697,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
PrefixLen: netProto.DefaultPrefixLen(),
},
}, peb, temporary, static, false)
-
- n.mu.Unlock()
return ref
}
@@ -1153,7 +1168,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return joins != 0
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1167,7 +1182,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
n.mu.RLock()
enabled := n.mu.enabled
// If the NIC is not yet enabled, don't receive any packets.
@@ -1197,27 +1212,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// Are any packet sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
- netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize())
+ // Parse headers.
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
+ // The packet is too small to contain a network header.
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- src, dst := netProto.ParseAddresses(netHeader)
+ if hasTransportHdr {
+ // Parse the transport header if present.
+ if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
+ state.proto.Parse(pkt)
+ }
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader)
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@@ -1229,18 +1251,19 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
// TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
- if protocol == header.IPv4ProtocolNumber {
+ // Loopback traffic skips the prerouting chain.
+ if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
- if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok {
+ if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
// iptables is telling us to drop the packet.
return
}
}
if ref := n.getRef(protocol, dst); ref != nil {
- handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt)
+ handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt)
return
}
@@ -1298,24 +1321,55 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol specific packet endpoints as on Linux
+ // only ETH_P_ALL endpoints get outbound packets.
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 {
- pkt.Header = buffer.NewPrependable(linkHeaderLen)
+ // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
+ // by having lower layers explicity write each header instead of just
+ // pkt.Header.
+
+ // pkt may have set its NetworkHeader and TransportHeader. If we're
+ // forwarding, we'll have to copy them into pkt.Header.
+ pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader))
+ if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader)))
+ }
+ if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader)))
}
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return
}
n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -1329,13 +1383,34 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
- transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
- if !ok {
+ // TransportHeader is nil only when pkt is an ICMP packet or was reassembled
+ // from fragments.
+ if pkt.TransportHeader == nil {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
+ // ICMP packets may be longer, but until icmp.Parse is implemented, here
+ // we parse it using the minimum size.
+ transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
+ if !ok {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+ pkt.TransportHeader = transHeader
+ pkt.Data.TrimFront(len(pkt.TransportHeader))
+ } else {
+ // This is either a bad packet or was re-assembled from fragments.
+ transProto.Parse(pkt)
+ }
+ }
+
+ if len(pkt.TransportHeader) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(transHeader)
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader)
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
@@ -1362,7 +1437,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
@@ -1477,6 +1552,27 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) {
n.mu.Unlock()
}
+// NUDConfigs gets the NUD configurations for n.
+func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) {
+ if n.neigh == nil {
+ return NUDConfigurations{}, tcpip.ErrNotSupported
+ }
+ return n.neigh.config(), nil
+}
+
+// setNUDConfigs sets the NUD configurations for n.
+//
+// Note, if c contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+ c.resetInvalidFields()
+ n.neigh.setConfig(c)
+ return nil
+}
+
// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
n.mu.Lock()
@@ -1611,8 +1707,8 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address
-// has been removed), or the NIC to be in spoofing mode.
+// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// has been removed) unless the NIC is in spoofing mode, or temporary.
func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
r.nic.mu.RLock()
defer r.nic.mu.RUnlock()
@@ -1620,13 +1716,28 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
return r.isValidForOutgoingRLocked()
}
-// isValidForOutgoingRLocked returns true if the endpoint can be used to send
-// out a packet. It requires the endpoint to not be marked expired (i.e., its
-// address has been removed), or the NIC to be in spoofing mode.
-//
-// r's NIC must be read locked.
+// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
+// r.nic.mu to be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
+ if !r.nic.mu.enabled {
+ return false
+ }
+
+ return r.isAssignedRLocked(r.nic.mu.spoofing)
+}
+
+// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
+//
+// r.nic.mu must be read locked.
+func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
+ switch r.getKind() {
+ case permanentTentative:
+ return false
+ case permanentExpired:
+ return spoofingOrPromiscuous
+ default:
+ return true
+ }
}
// expireLocked decrements the reference count and marks the permanent endpoint
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index d672fc157..a70792b50 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,11 +15,278 @@
package stack
import (
+ "math"
"testing"
+ "time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
+var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+
+// A LinkEndpoint that throws away outgoing packets.
+//
+// We use this instead of the channel endpoint as the channel package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (e *testLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*testLinkEndpoint) MTU() uint32 {
+ return math.MaxUint16
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return CapabilityResolutionRequired
+}
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*testLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements LinkEndpoint.Wait.
+func (*testLinkEndpoint) Wait() {}
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
+//
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Endpoint struct {
+ nicID tcpip.NICID
+ id NetworkEndpointID
+ prefixLen int
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
+}
+
+// DefaultTTL implements NetworkEndpoint.DefaultTTL.
+func (*testIPv6Endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+// MTU implements NetworkEndpoint.MTU.
+func (e *testIPv6Endpoint) MTU() uint32 {
+ return e.linkEP.MTU() - header.IPv6MinimumSize
+}
+
+// Capabilities implements NetworkEndpoint.Capabilities.
+func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
+func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// WritePacket implements NetworkEndpoint.WritePacket.
+func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements NetworkEndpoint.WritePackets.
+func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteHeaderIncludedPacket implements
+// NetworkEndpoint.WriteHeaderIncludedPacket.
+func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// ID implements NetworkEndpoint.ID.
+func (e *testIPv6Endpoint) ID() *NetworkEndpointID {
+ return &e.id
+}
+
+// PrefixLen implements NetworkEndpoint.PrefixLen.
+func (e *testIPv6Endpoint) PrefixLen() int {
+ return e.prefixLen
+}
+
+// NICID implements NetworkEndpoint.NICID.
+func (e *testIPv6Endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// HandlePacket implements NetworkEndpoint.HandlePacket.
+func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
+}
+
+// Close implements NetworkEndpoint.Close.
+func (*testIPv6Endpoint) Close() {}
+
+// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
+func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+var _ NetworkProtocol = (*testIPv6Protocol)(nil)
+
+// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
+// believe it supports IPv6.
+//
+// We use this instead of ipv6.protocol because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Protocol struct{}
+
+// Number implements NetworkProtocol.Number.
+func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize.
+func (*testIPv6Protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
+func (*testIPv6Protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint implements NetworkProtocol.NewEndpoint.
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+ return &testIPv6Endpoint{
+ nicID: nicID,
+ id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
+ prefixLen: addrWithPrefix.PrefixLen,
+ linkEP: linkEP,
+ protocol: p,
+ }, nil
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Option implements NetworkProtocol.Option.
+func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Close implements NetworkProtocol.Close.
+func (*testIPv6Protocol) Close() {}
+
+// Wait implements NetworkProtocol.Wait.
+func (*testIPv6Protocol) Wait() {}
+
+// Parse implements NetworkProtocol.Parse.
+func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ return 0, false, false
+}
+
+var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
+
+// LinkAddressProtocol implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
+ return nil
+}
+
+// ResolveStaticAddress implements LinkAddressResolver.
+func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
+ }
+ return "", false
+}
+
+// Test the race condition where a NIC is removed and an RS timer fires at the
+// same time.
+func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
+ const (
+ nicID = 1
+
+ maxRtrSolicitations = 5
+ )
+
+ e := testLinkEndpoint{}
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}},
+ NDPConfigs: NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: minimumRtrSolicitationInterval,
+ },
+ })
+
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ s.mu.Lock()
+ // Wait for the router solicitation timer to fire and block trying to obtain
+ // the stack lock when doing link address resolution.
+ time.Sleep(minimumRtrSolicitationInterval * 2)
+ if err := s.removeNICLocked(nicID); err != nil {
+ t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
+ }
+ s.mu.Unlock()
+}
+
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
@@ -44,7 +311,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
t.FailNow()
}
- nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+ nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
new file mode 100644
index 000000000..f848d50ad
--- /dev/null
+++ b/pkg/tcpip/stack/nud.go
@@ -0,0 +1,466 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // defaultBaseReachableTime is the default base duration for computing the
+ // random reachable time.
+ //
+ // Reachable time is the duration for which a neighbor is considered
+ // reachable after a positive reachability confirmation is received. It is a
+ // function of a uniformly distributed random value between the minimum and
+ // maximum random factors, multiplied by the base reachable time. Using a
+ // random component eliminates the possibility that Neighbor Unreachability
+ // Detection messages will synchronize with each other.
+ //
+ // Default taken from REACHABLE_TIME of RFC 4861 section 10.
+ defaultBaseReachableTime = 30 * time.Second
+
+ // minimumBaseReachableTime is the minimum base duration for computing the
+ // random reachable time.
+ //
+ // Minimum = 1ms
+ minimumBaseReachableTime = time.Millisecond
+
+ // defaultMinRandomFactor is the default minimum value of the random factor
+ // used for computing reachable time.
+ //
+ // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10.
+ defaultMinRandomFactor = 0.5
+
+ // defaultMaxRandomFactor is the default maximum value of the random factor
+ // used for computing reachable time.
+ //
+ // The default value depends on the value of MinRandomFactor.
+ // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10,
+ // the value from the RFC will be used; otherwise, the default is
+ // MinRandomFactor multiplied by three.
+ defaultMaxRandomFactor = 1.5
+
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
+ // defaultDelayFirstProbeTime is the default duration to wait for a
+ // non-Neighbor-Discovery related protocol to reconfirm reachability after
+ // entering the DELAY state. After this time, a reachability probe will be
+ // sent and the entry will transition to the PROBE state.
+ //
+ // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10.
+ defaultDelayFirstProbeTime = 5 * time.Second
+
+ // defaultMaxMulticastProbes is the default number of reachabililty probes
+ // to send before concluding negative reachability and deleting the neighbor
+ // entry from the INCOMPLETE state.
+ //
+ // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10.
+ defaultMaxMulticastProbes = 3
+
+ // defaultMaxUnicastProbes is the default number of reachability probes to
+ // send before concluding retransmission from within the PROBE state should
+ // cease and the entry SHOULD be deleted.
+ //
+ // Default taken from MAX_UNICASE_SOLICIT of RFC 4861 section 10.
+ defaultMaxUnicastProbes = 3
+
+ // defaultMaxAnycastDelayTime is the default time in which the stack SHOULD
+ // delay sending a response for a random time between 0 and this time, if the
+ // target address is an anycast address.
+ //
+ // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10.
+ defaultMaxAnycastDelayTime = time.Second
+
+ // defaultMaxReachbilityConfirmations is the default amount of unsolicited
+ // reachability confirmation messages a node MAY send to all-node multicast
+ // address when it determines its link-layer address has changed.
+ //
+ // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
+ defaultMaxReachbilityConfirmations = 3
+
+ // defaultUnreachableTime is the default duration for how long an entry will
+ // remain in the FAILED state before being removed from the neighbor cache.
+ //
+ // Note, there is no equivalent protocol constant defined in RFC 4861. It
+ // leaves the specifics of any garbage collection mechanism up to the
+ // implementation.
+ defaultUnreachableTime = 5 * time.Second
+)
+
+// NUDDispatcher is the interface integrators of netstack must implement to
+// receive and handle NUD related events.
+type NUDDispatcher interface {
+ // OnNeighborAdded will be called when a new entry is added to a NIC's (with
+ // ID nicID) neighbor table.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+
+ // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
+ // neighbor table changes state and/or link address.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+
+ // OnNeighborRemoved will be called when an entry is removed from a NIC's
+ // (with ID nicID) neighbor table.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+}
+
+// ReachabilityConfirmationFlags describes the flags used within a reachability
+// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP,
+// respectively).
+type ReachabilityConfirmationFlags struct {
+ // Solicited indicates that the advertisement was sent in response to a
+ // reachability probe.
+ Solicited bool
+
+ // Override indicates that the reachability confirmation should override an
+ // existing neighbor cache entry and update the cached link-layer address.
+ // When Override is not set the confirmation will not update a cached
+ // link-layer address, but will update an existing neighbor cache entry for
+ // which no link-layer address is known.
+ Override bool
+
+ // IsRouter indicates that the sender is a router.
+ IsRouter bool
+}
+
+// NUDHandler communicates external events to the Neighbor Unreachability
+// Detection state machine, which is implemented per-interface. This is used by
+// network endpoints to inform the Neighbor Cache of probes and confirmations.
+type NUDHandler interface {
+ // HandleProbe processes an incoming neighbor probe (e.g. ARP request or
+ // Neighbor Solicitation for ARP or NDP, respectively). Validation of the
+ // probe needs to be performed before calling this function since the
+ // Neighbor Cache doesn't have access to view the NIC's assigned addresses.
+ HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress)
+
+ // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
+ // reply or Neighbor Advertisement for ARP or NDP, respectively).
+ HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags)
+
+ // HandleUpperLevelConfirmation processes an incoming upper-level protocol
+ // (e.g. TCP acknowledgements) reachability confirmation.
+ HandleUpperLevelConfirmation(addr tcpip.Address)
+}
+
+// NUDConfigurations is the NUD configurations for the netstack. This is used
+// by the neighbor cache to operate the NUD state machine on each device in the
+// local network.
+type NUDConfigurations struct {
+ // BaseReachableTime is the base duration for computing the random reachable
+ // time.
+ //
+ // Reachable time is the duration for which a neighbor is considered
+ // reachable after a positive reachability confirmation is received. It is a
+ // function of uniformly distributed random value between minRandomFactor and
+ // maxRandomFactor multiplied by baseReachableTime. Using a random component
+ // eliminates the possibility that Neighbor Unreachability Detection messages
+ // will synchronize with each other.
+ //
+ // After this time, a neighbor entry will transition from REACHABLE to STALE
+ // state.
+ //
+ // Must be greater than 0.
+ BaseReachableTime time.Duration
+
+ // LearnBaseReachableTime enables learning BaseReachableTime during runtime
+ // from the neighbor discovery protocol, if supported.
+ //
+ // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option.
+ LearnBaseReachableTime bool
+
+ // MinRandomFactor is the minimum value of the random factor used for
+ // computing reachable time.
+ //
+ // See BaseReachbleTime for more information on computing the reachable time.
+ //
+ // Must be greater than 0.
+ MinRandomFactor float32
+
+ // MaxRandomFactor is the maximum value of the random factor used for
+ // computing reachabile time.
+ //
+ // See BaseReachbleTime for more information on computing the reachable time.
+ //
+ // Must be great than or equal to MinRandomFactor.
+ MaxRandomFactor float32
+
+ // RetransmitTimer is the duration between retransmission of reachability
+ // probes in the PROBE state.
+ RetransmitTimer time.Duration
+
+ // LearnRetransmitTimer enables learning RetransmitTimer during runtime from
+ // the neighbor discovery protocol, if supported.
+ //
+ // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option.
+ LearnRetransmitTimer bool
+
+ // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery
+ // related protocol to reconfirm reachability after entering the DELAY state.
+ // After this time, a reachability probe will be sent and the entry will
+ // transition to the PROBE state.
+ //
+ // Must be greater than 0.
+ DelayFirstProbeTime time.Duration
+
+ // MaxMulticastProbes is the number of reachability probes to send before
+ // concluding negative reachability and deleting the neighbor entry from the
+ // INCOMPLETE state.
+ //
+ // Must be greater than 0.
+ MaxMulticastProbes uint32
+
+ // MaxUnicastProbes is the number of reachability probes to send before
+ // concluding retransmission from within the PROBE state should cease and
+ // entry SHOULD be deleted.
+ //
+ // Must be greater than 0.
+ MaxUnicastProbes uint32
+
+ // MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a
+ // response for a random time between 0 and this time, if the target address
+ // is an anycast address.
+ //
+ // TODO(gvisor.dev/issue/2242): Use this option when sending solicited
+ // neighbor confirmations to anycast addresses and proxying neighbor
+ // confirmations.
+ MaxAnycastDelayTime time.Duration
+
+ // MaxReachabilityConfirmations is the number of unsolicited reachability
+ // confirmation messages a node MAY send to all-node multicast address when
+ // it determines its link-layer address has changed.
+ //
+ // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
+ // configuration option is necessary.
+ MaxReachabilityConfirmations uint32
+
+ // UnreachableTime describes how long an entry will remain in the FAILED
+ // state before being removed from the neighbor cache.
+ UnreachableTime time.Duration
+}
+
+// DefaultNUDConfigurations returns a NUDConfigurations populated with default
+// values defined by RFC 4861 section 10.
+func DefaultNUDConfigurations() NUDConfigurations {
+ return NUDConfigurations{
+ BaseReachableTime: defaultBaseReachableTime,
+ LearnBaseReachableTime: true,
+ MinRandomFactor: defaultMinRandomFactor,
+ MaxRandomFactor: defaultMaxRandomFactor,
+ RetransmitTimer: defaultRetransmitTimer,
+ LearnRetransmitTimer: true,
+ DelayFirstProbeTime: defaultDelayFirstProbeTime,
+ MaxMulticastProbes: defaultMaxMulticastProbes,
+ MaxUnicastProbes: defaultMaxUnicastProbes,
+ MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
+ MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
+ UnreachableTime: defaultUnreachableTime,
+ }
+}
+
+// resetInvalidFields modifies an invalid NDPConfigurations with valid values.
+// If invalid values are present in c, the corresponding default values will be
+// used instead. This is needed to check, and conditionally fix, user-specified
+// NUDConfigurations.
+func (c *NUDConfigurations) resetInvalidFields() {
+ if c.BaseReachableTime < minimumBaseReachableTime {
+ c.BaseReachableTime = defaultBaseReachableTime
+ }
+ if c.MinRandomFactor <= 0 {
+ c.MinRandomFactor = defaultMinRandomFactor
+ }
+ if c.MaxRandomFactor < c.MinRandomFactor {
+ c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor)
+ }
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+ if c.DelayFirstProbeTime == 0 {
+ c.DelayFirstProbeTime = defaultDelayFirstProbeTime
+ }
+ if c.MaxMulticastProbes == 0 {
+ c.MaxMulticastProbes = defaultMaxMulticastProbes
+ }
+ if c.MaxUnicastProbes == 0 {
+ c.MaxUnicastProbes = defaultMaxUnicastProbes
+ }
+ if c.UnreachableTime == 0 {
+ c.UnreachableTime = defaultUnreachableTime
+ }
+}
+
+// calcMaxRandomFactor calculates the maximum value of the random factor used
+// for computing reachable time. This function is necessary for when the
+// default specified in RFC 4861 section 10 is less than the current
+// MinRandomFactor.
+//
+// Assumes minRandomFactor is positive since validation of the minimum value
+// should come before the validation of the maximum.
+func calcMaxRandomFactor(minRandomFactor float32) float32 {
+ if minRandomFactor > defaultMaxRandomFactor {
+ return minRandomFactor * 3
+ }
+ return defaultMaxRandomFactor
+}
+
+// A Rand is a source of random numbers.
+type Rand interface {
+ // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0).
+ Float32() float32
+}
+
+// NUDState stores states needed for calculating reachable time.
+type NUDState struct {
+ rng Rand
+
+ // mu protects the fields below.
+ //
+ // It is necessary for NUDState to handle its own locking since neighbor
+ // entries may access the NUD state from within the goroutine spawned by
+ // time.AfterFunc(). This goroutine may run concurrently with the main
+ // process for controlling the neighbor cache and would otherwise introduce
+ // race conditions if NUDState was not locked properly.
+ mu sync.RWMutex
+
+ config NUDConfigurations
+
+ // reachableTime is the duration to wait for a REACHABLE entry to
+ // transition into STALE after inactivity. This value is calculated with
+ // the algorithm defined in RFC 4861 section 6.3.2.
+ reachableTime time.Duration
+
+ expiration time.Time
+ prevBaseReachableTime time.Duration
+ prevMinRandomFactor float32
+ prevMaxRandomFactor float32
+}
+
+// NewNUDState returns new NUDState using c as configuration and the specified
+// random number generator for use in recomputing ReachableTime.
+func NewNUDState(c NUDConfigurations, rng Rand) *NUDState {
+ s := &NUDState{
+ rng: rng,
+ }
+ s.config = c
+ return s
+}
+
+// Config returns the NUD configuration.
+func (s *NUDState) Config() NUDConfigurations {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.config
+}
+
+// SetConfig replaces the existing NUD configurations with c.
+func (s *NUDState) SetConfig(c NUDConfigurations) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.config = c
+}
+
+// ReachableTime returns the duration to wait for a REACHABLE entry to
+// transition into STALE after inactivity. This value is recalculated for new
+// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the
+// algorithm defined in RFC 4861 section 6.3.2.
+func (s *NUDState) ReachableTime() time.Duration {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if time.Now().After(s.expiration) ||
+ s.config.BaseReachableTime != s.prevBaseReachableTime ||
+ s.config.MinRandomFactor != s.prevMinRandomFactor ||
+ s.config.MaxRandomFactor != s.prevMaxRandomFactor {
+ return s.recomputeReachableTimeLocked()
+ }
+ return s.reachableTime
+}
+
+// recomputeReachableTimeLocked forces a recalculation of ReachableTime using
+// the algorithm defined in RFC 4861 section 6.3.2.
+//
+// This SHOULD automatically be invoked during certain situations, as per
+// RFC 4861 section 6.3.4:
+//
+// If the received Reachable Time value is non-zero, the host SHOULD set its
+// BaseReachableTime variable to the received value. If the new value
+// differs from the previous value, the host SHOULD re-compute a new random
+// ReachableTime value. ReachableTime is computed as a uniformly
+// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR
+// times the BaseReachableTime. Using a random component eliminates the
+// possibility that Neighbor Unreachability Detection messages will
+// synchronize with each other.
+//
+// In most cases, the advertised Reachable Time value will be the same in
+// consecutive Router Advertisements, and a host's BaseReachableTime rarely
+// changes. In such cases, an implementation SHOULD ensure that a new
+// random value gets re-computed at least once every few hours.
+//
+// s.mu MUST be locked for writing.
+func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+ s.prevBaseReachableTime = s.config.BaseReachableTime
+ s.prevMinRandomFactor = s.config.MinRandomFactor
+ s.prevMaxRandomFactor = s.config.MaxRandomFactor
+
+ randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor)
+
+ // Check for overflow, given that minRandomFactor and maxRandomFactor are
+ // guaranteed to be positive numbers.
+ if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) {
+ s.reachableTime = time.Duration(math.MaxInt64)
+ } else if randomFactor == 1 {
+ // Avoid loss of precision when a large base reachable time is used.
+ s.reachableTime = s.config.BaseReachableTime
+ } else {
+ reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor)
+ s.reachableTime = time.Duration(reachableTime)
+ }
+
+ s.expiration = time.Now().Add(2 * time.Hour)
+ return s.reachableTime
+}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
new file mode 100644
index 000000000..2494ee610
--- /dev/null
+++ b/pkg/tcpip/stack/nud_test.go
@@ -0,0 +1,795 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ defaultBaseReachableTime = 30 * time.Second
+ minimumBaseReachableTime = time.Millisecond
+ defaultMinRandomFactor = 0.5
+ defaultMaxRandomFactor = 1.5
+ defaultRetransmitTimer = time.Second
+ minimumRetransmitTimer = time.Millisecond
+ defaultDelayFirstProbeTime = 5 * time.Second
+ defaultMaxMulticastProbes = 3
+ defaultMaxUnicastProbes = 3
+ defaultMaxAnycastDelayTime = time.Second
+ defaultMaxReachbilityConfirmations = 3
+ defaultUnreachableTime = 5 * time.Second
+
+ defaultFakeRandomNum = 0.5
+)
+
+// fakeRand is a deterministic random number generator.
+type fakeRand struct {
+ num float32
+}
+
+var _ stack.Rand = (*fakeRand)(nil)
+
+func (f *fakeRand) Float32() float32 {
+ return f.num
+}
+
+// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NUD configurations using an invalid NICID.
+func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The networking
+ // stack will only allocate neighbor caches if a protocol providing link
+ // address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+
+ // No NIC with ID 1 yet.
+ config := stack.NUDConfigurations{}
+ if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
+// NotSupported error if we attempt to retrieve NUD configurations when the
+// stack doesn't support NUD.
+//
+// The stack will report to not support NUD if a neighbor cache for a given NIC
+// is not allocated. The networking stack will only allocate neighbor caches if
+// a protocol providing link address resolution is specified (e.g. ARP, IPv6).
+func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported {
+ t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported)
+ }
+}
+
+// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
+// NotSupported error if we attempt to set NUD configurations when the stack
+// doesn't support NUD.
+//
+// The stack will report to not support NUD if a neighbor cache for a given NIC
+// is not allocated. The networking stack will only allocate neighbor caches if
+// a protocol providing link address resolution is specified (e.g. ARP, IPv6).
+func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ config := stack.NUDConfigurations{}
+ if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported {
+ t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported)
+ }
+}
+
+// TestDefaultNUDConfigurationIsValid verifies that calling
+// resetInvalidFields() on the result of DefaultNUDConfigurations() does not
+// change anything. DefaultNUDConfigurations() should return a valid
+// NUDConfigurations.
+func TestDefaultNUDConfigurations(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The networking
+ // stack will only allocate neighbor caches if a protocol providing link
+ // address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ c, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got, want := c, stack.DefaultNUDConfigurations(); got != want {
+ t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want)
+ }
+}
+
+func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ baseReachableTime: 0,
+ want: defaultBaseReachableTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ baseReachableTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ {
+ name: "MoreThanDefaultBaseReachableTime",
+ baseReachableTime: 2 * defaultBaseReachableTime,
+ want: 2 * defaultBaseReachableTime,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.BaseReachableTime = test.baseReachableTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.BaseReachableTime; got != test.want {
+ t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
+ tests := []struct {
+ name string
+ minRandomFactor float32
+ want float32
+ }{
+ // Invalid cases
+ {
+ name: "LessThanZero",
+ minRandomFactor: -1,
+ want: defaultMinRandomFactor,
+ },
+ {
+ name: "EqualToZero",
+ minRandomFactor: 0,
+ want: defaultMinRandomFactor,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ minRandomFactor: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MinRandomFactor = test.minRandomFactor
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MinRandomFactor; got != test.want {
+ t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
+ tests := []struct {
+ name string
+ minRandomFactor float32
+ maxRandomFactor float32
+ want float32
+ }{
+ // Invalid cases
+ {
+ name: "LessThanZero",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: -1,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "EqualToZero",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: 0,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "LessThanMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor * 0.99,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault",
+ minRandomFactor: defaultMaxRandomFactor * 2,
+ maxRandomFactor: defaultMaxRandomFactor,
+ want: defaultMaxRandomFactor * 6,
+ },
+ // Valid cases
+ {
+ name: "EqualToMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor,
+ want: defaultMinRandomFactor,
+ },
+ {
+ name: "MoreThanMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor * 1.1,
+ want: defaultMinRandomFactor * 1.1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MinRandomFactor = test.minRandomFactor
+ c.MaxRandomFactor = test.maxRandomFactor
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxRandomFactor; got != test.want {
+ t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
+ tests := []struct {
+ name string
+ retransmitTimer time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ retransmitTimer: 0,
+ want: defaultRetransmitTimer,
+ },
+ {
+ name: "LessThanMinimumRetransmitTimer",
+ retransmitTimer: minimumRetransmitTimer - time.Nanosecond,
+ want: defaultRetransmitTimer,
+ },
+ // Valid cases
+ {
+ name: "EqualToMinimumRetransmitTimer",
+ retransmitTimer: minimumRetransmitTimer,
+ want: minimumBaseReachableTime,
+ },
+ {
+ name: "LargetThanMinimumRetransmitTimer",
+ retransmitTimer: 2 * minimumBaseReachableTime,
+ want: 2 * minimumBaseReachableTime,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.RetransmitTimer = test.retransmitTimer
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.RetransmitTimer; got != test.want {
+ t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
+ tests := []struct {
+ name string
+ delayFirstProbeTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ delayFirstProbeTime: 0,
+ want: defaultDelayFirstProbeTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ delayFirstProbeTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.DelayFirstProbeTime = test.delayFirstProbeTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.DelayFirstProbeTime; got != test.want {
+ t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
+ tests := []struct {
+ name string
+ maxMulticastProbes uint32
+ want uint32
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ maxMulticastProbes: 0,
+ want: defaultMaxMulticastProbes,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ maxMulticastProbes: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MaxMulticastProbes = test.maxMulticastProbes
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxMulticastProbes; got != test.want {
+ t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
+ tests := []struct {
+ name string
+ maxUnicastProbes uint32
+ want uint32
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ maxUnicastProbes: 0,
+ want: defaultMaxUnicastProbes,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ maxUnicastProbes: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MaxUnicastProbes = test.maxUnicastProbes
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxUnicastProbes; got != test.want {
+ t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsUnreachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ unreachableTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ unreachableTime: 0,
+ want: defaultUnreachableTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ unreachableTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.UnreachableTime = test.unreachableTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.UnreachableTime; got != test.want {
+ t.Errorf("got UnreachableTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+// TestNUDStateReachableTime verifies the correctness of the ReachableTime
+// computation.
+func TestNUDStateReachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ minRandomFactor float32
+ maxRandomFactor float32
+ want time.Duration
+ }{
+ {
+ name: "AllZeros",
+ baseReachableTime: 0,
+ minRandomFactor: 0,
+ maxRandomFactor: 0,
+ want: 0,
+ },
+ {
+ name: "ZeroMaxRandomFactor",
+ baseReachableTime: time.Second,
+ minRandomFactor: 0,
+ maxRandomFactor: 0,
+ want: 0,
+ },
+ {
+ name: "ZeroMinRandomFactor",
+ baseReachableTime: time.Second,
+ minRandomFactor: 0,
+ maxRandomFactor: 1,
+ want: time.Duration(defaultFakeRandomNum * float32(time.Second)),
+ },
+ {
+ name: "FractionalRandomFactor",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 0.001,
+ maxRandomFactor: 0.002,
+ want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)),
+ },
+ {
+ name: "MinAndMaxRandomFactorsEqual",
+ baseReachableTime: time.Second,
+ minRandomFactor: 1,
+ maxRandomFactor: 1,
+ want: time.Second,
+ },
+ {
+ name: "MinAndMaxRandomFactorsDifferent",
+ baseReachableTime: time.Second,
+ minRandomFactor: 1,
+ maxRandomFactor: 2,
+ want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)),
+ },
+ {
+ name: "MaxInt64",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 1,
+ maxRandomFactor: 1,
+ want: time.Duration(math.MaxInt64),
+ },
+ {
+ name: "Overflow",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 1.5,
+ maxRandomFactor: 1.5,
+ want: time.Duration(math.MaxInt64),
+ },
+ {
+ name: "DoubleOverflow",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 2.5,
+ maxRandomFactor: 2.5,
+ want: time.Duration(math.MaxInt64),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := stack.NUDConfigurations{
+ BaseReachableTime: test.baseReachableTime,
+ MinRandomFactor: test.minRandomFactor,
+ MaxRandomFactor: test.maxRandomFactor,
+ }
+ // A fake random number generator is used to ensure deterministic
+ // results.
+ rng := fakeRand{
+ num: defaultFakeRandomNum,
+ }
+ s := stack.NewNUDState(c, &rng)
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+ })
+ }
+}
+
+// TestNUDStateRecomputeReachableTime exercises the ReachableTime function
+// twice to verify recomputation of reachable time when the min random factor,
+// max random factor, or base reachable time changes.
+func TestNUDStateRecomputeReachableTime(t *testing.T) {
+ const defaultBase = time.Second
+ const defaultMin = 2.0 * defaultMaxRandomFactor
+ const defaultMax = 3.0 * defaultMaxRandomFactor
+
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ minRandomFactor float32
+ maxRandomFactor float32
+ want time.Duration
+ }{
+ {
+ name: "BaseReachableTime",
+ baseReachableTime: 2 * defaultBase,
+ minRandomFactor: defaultMin,
+ maxRandomFactor: defaultMax,
+ want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
+ },
+ {
+ name: "MinRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: defaultMax,
+ maxRandomFactor: defaultMax,
+ want: time.Duration(defaultMax * float32(defaultBase)),
+ },
+ {
+ name: "MaxRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: defaultMin,
+ maxRandomFactor: defaultMin,
+ want: time.Duration(defaultMin * float32(defaultBase)),
+ },
+ {
+ name: "BothRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: 2 * defaultMin,
+ maxRandomFactor: 2 * defaultMax,
+ want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)),
+ },
+ {
+ name: "BaseReachableTimeAndBothRandomFactors",
+ baseReachableTime: 2 * defaultBase,
+ minRandomFactor: 2 * defaultMin,
+ maxRandomFactor: 2 * defaultMax,
+ want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := stack.DefaultNUDConfigurations()
+ c.BaseReachableTime = defaultBase
+ c.MinRandomFactor = defaultMin
+ c.MaxRandomFactor = defaultMax
+
+ // A fake random number generator is used to ensure deterministic
+ // results.
+ rng := fakeRand{
+ num: defaultFakeRandomNum,
+ }
+ s := stack.NewNUDState(c, &rng)
+ old := s.ReachableTime()
+
+ if got, want := s.ReachableTime(), old; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+
+ // Check for recomputation when changing the min random factor, the max
+ // random factor, the base reachability time, or any permutation of those
+ // three options.
+ c.BaseReachableTime = test.baseReachableTime
+ c.MinRandomFactor = test.minRandomFactor
+ c.MaxRandomFactor = test.maxRandomFactor
+ s.SetConfig(c)
+
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+
+ // Verify that ReachableTime isn't recomputed when none of the
+ // configuration options change. The random factor is changed so that if
+ // a recompution were to occur, ReachableTime would change.
+ rng.num = defaultFakeRandomNum / 2.0
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 926df4d7b..5d6865e35 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -14,6 +14,7 @@
package stack
import (
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -24,6 +25,8 @@ import (
// multiple endpoints. Clone() should be called in such cases so that
// modifications to the Data field do not affect other copies.
type PacketBuffer struct {
+ _ sync.NoCopy
+
// PacketBufferEntry is used to build an intrusive list of
// PacketBuffers.
PacketBufferEntry
@@ -76,13 +79,31 @@ type PacketBuffer struct {
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
// VectorisedView but does not deep copy the underlying bytes.
//
// Clone also does not deep copy any of its other fields.
-func (pk PacketBuffer) Clone() PacketBuffer {
- pk.Data = pk.Data.Clone(nil)
- return pk
+//
+// FIXME(b/153685824): Data gets copied but not other header references.
+func (pk *PacketBuffer) Clone() *PacketBuffer {
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ Header: pk.Header,
+ LinkHeader: pk.LinkHeader,
+ NetworkHeader: pk.NetworkHeader,
+ TransportHeader: pk.TransportHeader,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ }
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index b331427c6..8604c4259 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -51,8 +52,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -67,12 +71,12 @@ type TransportEndpoint interface {
// this transport endpoint. It sets pkt.TransportHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer)
+ HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
// HandleControlPacket takes ownership of pkt.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer)
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
// in a closed state and frees all resources associated with it. This
@@ -100,7 +104,7 @@ type RawTransportEndpoint interface {
// layer up.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt PacketBuffer)
+ HandlePacket(r *Route, pkt *PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -118,7 +122,7 @@ type PacketEndpoint interface {
// should construct its own ethernet header for applications.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -150,7 +154,7 @@ type TransportProtocol interface {
// stats purposes only).
//
// HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -168,6 +172,11 @@ type TransportProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
+
+ // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does
+ // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() <
+ // MinimumPacketSize()
+ Parse(pkt *PacketBuffer) (ok bool)
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -180,7 +189,7 @@ type TransportDispatcher interface {
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
// DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -189,7 +198,7 @@ type TransportDispatcher interface {
// DeliverTransportControlPacket.
//
// DeliverTransportControlPacket takes ownership of pkt.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer)
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -240,17 +249,18 @@ type NetworkEndpoint interface {
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
- // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
- // already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error
+ // protocol. It takes ownership of pkt. pkt.TransportHeader must have already
+ // been set.
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
- // protocol. pkts must not be zero length.
+ // protocol. pkts must not be zero length. It takes ownership of pkts and
+ // underlying packets.
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
- // header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error
+ // header to the given destination address. It takes ownership of pkt.
+ WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -265,7 +275,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt PacketBuffer)
+ HandlePacket(r *Route, pkt *PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -312,11 +322,18 @@ type NetworkProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
+
+ // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It
+ // returns:
+ // - The encapsulated protocol, if present.
+ // - Whether there is an encapsulated transport protocol payload (e.g. ARP
+ // does not encapsulate anything).
+ // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
+ Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network/packet(if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -326,7 +343,17 @@ type NetworkDispatcher interface {
// packets sent via loopback), and won't have the field set.
//
// DeliverNetworkPacket takes ownership of pkt.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -382,17 +409,17 @@ type LinkEndpoint interface {
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the
- // given route. It sets pkt.LinkHeader if a link layer header exists.
- // pkt.NetworkHeader and pkt.TransportHeader must have already been
- // set.
+ // given route. It takes ownership of pkt. pkt.NetworkHeader and
+ // pkt.TransportHeader must have already been set.
//
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets with the given protocol through the
- // given route. pkts must not be zero length.
+ // given route. pkts must not be zero length. It takes ownership of pkts and
+ // underlying packets.
//
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, it may
@@ -400,7 +427,7 @@ type LinkEndpoint interface {
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
// WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header.
+ // should already have an ethernet header. It takes ownership of vv.
WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
// Attach attaches the data link layer endpoint to the network-layer
@@ -422,6 +449,15 @@ type LinkEndpoint interface {
// Wait will not block if the endpoint hasn't started any goroutines
// yet, even if it might later.
Wait()
+
+ // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
+ ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -430,7 +466,7 @@ type InjectableLinkEndpoint interface {
LinkEndpoint
// InjectInbound injects an inbound packet.
- InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
// InjectOutbound writes a fully formed outbound packet directly to the
// link.
@@ -442,12 +478,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr.
- // The request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+ // the request on the local network if remoteLinkAddr is the zero value. The
+ // request is sent on linkEP with localAddr as the source.
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+ LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 150297ab9..91e0110f1 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -48,6 +48,10 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
+
+ // directedBroadcast indicates whether this route is sending a directed
+ // broadcast packet.
+ directedBroadcast bool
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -113,6 +117,8 @@ func (r *Route) GSOMaxSize() uint32 {
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
+//
+// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if !r.IsResolutionRequired() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
@@ -148,22 +154,27 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
+//
+// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+
err := r.ref.ep.WritePacket(r, gso, params, pkt)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
}
return err
}
@@ -175,9 +186,12 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
return 0, tcpip.ErrInvalidEndpointState
}
+ // WritePackets takes ownership of pkt, calculate length first.
+ numPkts := pkts.Len()
+
n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
}
r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
@@ -193,17 +207,20 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
+ // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Data.Size()
+
if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
return nil
}
@@ -262,6 +279,12 @@ func (r *Route) Stack() *Stack {
return r.ref.stack()
}
+// IsBroadcast returns true if the route is to send a broadcast packet.
+func (r *Route) IsBroadcast() bool {
+ // Only IPv4 has a notion of broadcast.
+ return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast
+}
+
// ReverseRoute returns new route with given source and destination address.
func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
return Route{
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0ab4c3e19..3f07e4159 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -52,7 +52,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -424,12 +424,9 @@ type Stack struct {
// handleLocal allows non-loopback interfaces to loop packets.
handleLocal bool
- // tablesMu protects iptables.
- tablesMu sync.RWMutex
-
- // tables are the iptables packet filtering and manipulation rules. The are
- // protected by tablesMu.`
- tables IPTables
+ // tables are the iptables packet filtering and manipulation rules.
+ // TODO(gvisor.dev/issue/170): S/R this field.
+ tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
@@ -448,6 +445,9 @@ type Stack struct {
// ndpConfigs is the default NDP configurations used by interfaces.
ndpConfigs NDPConfigurations
+ // nudConfigs is the default NUD configurations used by interfaces.
+ nudConfigs NUDConfigurations
+
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
// to auto-generate an IPv6 link-local address for newly enabled non-loopback
// NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
@@ -457,6 +457,10 @@ type Stack struct {
// integrator NDP related events.
ndpDisp NDPDispatcher
+ // nudDisp is the NUD event dispatcher that is used to send the netstack
+ // integrator NUD related events.
+ nudDisp NUDDispatcher
+
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
@@ -475,6 +479,14 @@ type Stack struct {
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
randomGenerator *mathrand.Rand
+
+ // sendBufferSize holds the min/default/max send buffer sizes for
+ // endpoints other than TCP.
+ sendBufferSize SendBufferSizeOption
+
+ // receiveBufferSize holds the min/default/max receive buffer sizes for
+ // endpoints other than TCP.
+ receiveBufferSize ReceiveBufferSizeOption
}
// UniqueID is an abstract generator of unique identifiers.
@@ -513,6 +525,9 @@ type Options struct {
// before assigning an address to a NIC.
NDPConfigs NDPConfigurations
+ // NUDConfigs is the default NUD configurations used by interfaces.
+ NUDConfigs NUDConfigurations
+
// AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
// auto-generate an IPv6 link-local address for newly enabled non-loopback
// NICs.
@@ -531,6 +546,10 @@ type Options struct {
// receive NDP related events.
NDPDisp NDPDispatcher
+ // NUDDisp is the NUD event dispatcher that an integrator can provide to
+ // receive NUD related events.
+ NUDDisp NUDDispatcher
+
// RawFactory produces raw endpoints. Raw endpoints are enabled only if
// this is non-nil.
RawFactory RawFactory
@@ -665,6 +684,8 @@ func New(opts Options) *Stack {
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
+ opts.NUDConfigs.resetInvalidFields()
+
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
@@ -676,16 +697,29 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
+ nudConfigs: opts.NUDConfigs,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
+ nudDisp: opts.NUDDisp,
opaqueIIDOpts: opts.OpaqueIIDOpts,
tempIIDSeed: opts.TempIIDSeed,
forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
+ receiveBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
}
// Add specified network protocols.
@@ -712,6 +746,11 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
// UniqueID returns a unique identifier.
func (s *Stack) UniqueID() uint64 {
return s.uniqueIDGenerator.UniqueID()
@@ -778,16 +817,17 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
// Stats returns a mutable copy of the current stats.
@@ -1020,6 +1060,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
+ return s.removeNICLocked(id)
+}
+
+// removeNICLocked removes NIC and all related routes from the network stack.
+//
+// s.mu must be locked.
+func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
nic, ok := s.nics[id]
if !ok {
return tcpip.ErrUnknownNICID
@@ -1029,14 +1076,14 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
// Remove routes in-place. n tracks the number of routes written.
n := 0
for i, r := range s.routeTable {
+ s.routeTable[i] = tcpip.Route{}
if r.NIC != id {
// Keep this route.
- if i > n {
- s.routeTable[n] = r
- }
+ s.routeTable[n] = r
n++
}
}
+
s.routeTable = s.routeTable[:n]
return nic.remove()
@@ -1072,6 +1119,11 @@ type NICInfo struct {
// Context is user-supplied data optionally supplied in CreateNICWithOptions.
// See type NICOptions for more details.
Context NICContext
+
+ // ARPHardwareType holds the ARP Hardware type of the NIC. This is the
+ // value sent in haType field of an ARP Request sent by this NIC and the
+ // value expected in the haType field of an ARP response.
+ ARPHardwareType header.ARPHardwareType
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -1103,6 +1155,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
Context: nic.context,
+ ARPHardwareType: nic.linkEP.ARPHardwareType(),
}
}
return nics
@@ -1249,9 +1302,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
s.mu.RLock()
defer s.mu.RUnlock()
- isBroadcast := remoteAddr == header.IPv4Broadcast
+ isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
@@ -1272,9 +1325,16 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
- if needRoute {
- r.NextHop = route.Gateway
+ r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr)
+
+ if len(route.Gateway) > 0 {
+ if needRoute {
+ r.NextHop = route.Gateway
+ }
+ } else if r.directedBroadcast {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
+
return r, nil
}
}
@@ -1400,25 +1460,31 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
+}
+
+// CheckRegisterTransportEndpoint checks if an endpoint can be registered with
+// the stack transport dispatcher.
+func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
-func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
s.cleanupEndpoints[ep] = struct{}{}
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
@@ -1741,18 +1807,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool,
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() IPTables {
- s.tablesMu.RLock()
- t := s.tables
- s.tablesMu.RUnlock()
- return t
-}
-
-// SetIPTables sets the stack's iptables.
-func (s *Stack) SetIPTables(ipt IPTables) {
- s.tablesMu.Lock()
- s.tables = ipt
- s.tablesMu.Unlock()
+func (s *Stack) IPTables() *IPTables {
+ return s.tables
}
// ICMPLimit returns the maximum number of ICMP messages that can be sent
@@ -1831,10 +1887,38 @@ func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip
}
nic.setNDPConfigs(c)
-
return nil
}
+// NUDConfigurations gets the per-interface NUD configurations.
+func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) {
+ s.mu.RLock()
+ nic, ok := s.nics[id]
+ s.mu.RUnlock()
+
+ if !ok {
+ return NUDConfigurations{}, tcpip.ErrUnknownNICID
+ }
+
+ return nic.NUDConfigs()
+}
+
+// SetNUDConfigurations sets the per-interface NUD configurations.
+//
+// Note, if c contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[id]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.setNUDConfigs(c)
+}
+
// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement
// message that it needs to handle.
func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
new file mode 100644
index 000000000..0b093e6c5
--- /dev/null
+++ b/pkg/tcpip/stack/stack_options.go
@@ -0,0 +1,106 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+const (
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4 KiB
+
+ // DefaultBufferSize is the default size of the send/recv buffer for a
+ // transport endpoint.
+ DefaultBufferSize = 212 << 10 // 212 KiB
+
+ // DefaultMaxBufferSize is the default maximum permitted size of a
+ // send/receive buffer.
+ DefaultMaxBufferSize = 4 << 20 // 4 MiB
+)
+
+// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
+// get/set the default, min and max send buffer sizes.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to
+// get/set the default, min and max receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// SetOption allows setting stack wide options.
+func (s *Stack) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SendBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below minimum
+ // required for stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.sendBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below minimum
+ // required for stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.receiveBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option allows retrieving stack wide options.
+func (s *Stack) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SendBufferSizeOption:
+ s.mu.RLock()
+ *v = s.sendBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ s.mu.RLock()
+ *v = s.receiveBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 1a2cf007c..f22062889 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,6 +27,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -52,6 +53,10 @@ const (
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
defaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
@@ -90,30 +95,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
- // Consume the network header.
- b, ok := pkt.Data.PullUp(fakeNetHeaderLen)
- if !ok {
- return
- }
- pkt.Data.TrimFront(fakeNetHeaderLen)
-
// Handle control packets.
- if b[2] == uint8(fakeControlProtocol) {
+ if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
+ f.dispatcher.DeliverTransportControlPacket(
+ tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ fakeNetNumber,
+ tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ stack.ControlPortUnreachable, 0, pkt)
return
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -132,24 +135,19 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe
return f.proto.Number()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
// Add the protocol's header to the packet and send it to the link
// endpoint.
- b := pkt.Header.Prepend(fakeNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
+ pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen)
+ pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0]
+ pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0]
+ pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- f.HandlePacket(r, stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
+ f.HandlePacket(r, pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -163,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -205,7 +203,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
}
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
@@ -247,6 +245,17 @@ func (*fakeNetworkProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
+// Parse implements TransportProtocol.Parse.
+func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
+}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
@@ -292,8 +301,8 @@ func TestNetworkReceive(t *testing.T) {
buf := buffer.NewView(30)
// Make sure packet with wrong address is not delivered.
- buf[0] = 3
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 3
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 0 {
@@ -304,8 +313,8 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to first endpoint.
- buf[0] = 1
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 1
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -316,8 +325,8 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to second endpoint.
- buf[0] = 2
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = 2
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -328,7 +337,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -340,7 +349,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -362,7 +371,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
})
@@ -420,7 +429,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got := fakeNet.PacketCount(localAddrByte); got != want {
@@ -859,9 +868,9 @@ func TestRouteWithDownNIC(t *testing.T) {
// Writes with Routes that use NIC1 after being brought up should
// succeed.
//
- // TODO(b/147015577): Should we instead completely invalidate all
- // Routes that were bound to a NIC that was brought down at some
- // point?
+ // TODO(gvisor.dev/issue/1491): Should we instead completely
+ // invalidate all Routes that were bound to a NIC that was brought
+ // down at some point?
if err := upFn(s, nicID1); err != nil {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
@@ -982,7 +991,7 @@ func TestAddressRemoval(t *testing.T) {
buf := buffer.NewView(30)
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1032,7 +1041,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1114,7 +1123,7 @@ func TestEndpointExpiration(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
if promiscuous {
if err := s.SetPromiscuousMode(nicID, true); err != nil {
@@ -1277,7 +1286,7 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
@@ -1658,7 +1667,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -1766,7 +1775,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -2263,7 +2272,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
@@ -2344,8 +2353,8 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to dstAddr.
buf := buffer.NewView(30)
- buf[0] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ buf[dstAddrOffset] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -3297,7 +3306,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
@@ -3330,3 +3339,305 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
}
}
+
+func TestStackReceiveBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ rs stack.ReceiveBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid Configurations
+ {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.rs); err != tc.err {
+ t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
+ }
+ var rs stack.ReceiveBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&rs); err != nil {
+ t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
+ }
+ if got, want := rs, tc.rs; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestStackSendBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ ss stack.SendBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid Configurations
+ {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.ss); err != tc.err {
+ t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err)
+ }
+ var ss stack.SendBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&ss); err != nil {
+ t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
+ }
+ if got, want := ss, tc.ss; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID1 = 1
+ )
+
+ defaultAddr := tcpip.AddressWithPrefix{
+ Address: header.IPv4Any,
+ PrefixLen: 0,
+ }
+ defaultSubnet := defaultAddr.Subnet()
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedRoute stack.Route
+ }{
+ // Broadcast to a locally attached subnet populates the broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: ipv4SubnetBcast,
+ RemoteLinkAddress: header.EthernetBroadcastAddress,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a locally attached /31 subnet does not populate the
+ // broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4AddrPrefix31.Address,
+ RemoteAddress: ipv4Subnet31Bcast,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a locally attached /32 subnet does not populate the
+ // broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4AddrPrefix32.Address,
+ RemoteAddress: ipv4Subnet32Bcast,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv6Addr.Address,
+ RemoteAddress: ipv6SubnetBcast,
+ NetProto: header.IPv6ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a remote subnet in the route table is send to the next-hop
+ // gateway.
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: remNetSubnetBcast,
+ NextHop: ipv4Gateway,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to an unknown subnet follows the default route. Note that this
+ // is essentially just routing an unknown destination IP, because w/o any
+ // subnet prefix information a subnet broadcast address is just a normal IP.
+ {
+ name: "IPv4 Broadcast to unknown subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: defaultSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: remNetSubnetBcast,
+ NextHop: ipv4Gateway,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ })
+ ep := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
+ } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
+ t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 9a33ed375..b902c6ca9 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -15,7 +15,6 @@
package stack
import (
- "container/heap"
"fmt"
"math/rand"
@@ -23,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
)
type protocolIDs struct {
@@ -43,14 +43,14 @@ type transportEndpoints struct {
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
epsByNIC, ok := eps.endpoints[id]
if !ok {
return
}
- if !epsByNIC.unregisterEndpoint(bindToDevice, ep) {
+ if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
return
}
delete(eps.endpoints, id)
@@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
@@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -204,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
@@ -214,23 +214,34 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
demux: d,
netProto: netProto,
transProto: transProto,
- reuse: reusePort,
}
epsByNIC.endpoints[bindToDevice] = multiPortEp
}
- return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ return multiPortEp.singleRegisterEndpoint(t, flags)
+}
+
+func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
+
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
+ if !ok {
+ return nil
+ }
+
+ return multiPortEp.singleCheckEndpoint(flags)
}
// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
-func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
return false
}
- if multiPortEp.unregisterEndpoint(t) {
+ if multiPortEp.unregisterEndpoint(t, flags) {
delete(epsByNIC.endpoints, bindToDevice)
}
return len(epsByNIC.endpoints) == 0
@@ -251,7 +262,7 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer)
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
@@ -279,10 +290,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
return err
}
}
@@ -290,33 +301,15 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
return nil
}
-type transportEndpointHeap []TransportEndpoint
-
-var _ heap.Interface = (*transportEndpointHeap)(nil)
-
-func (h *transportEndpointHeap) Len() int {
- return len(*h)
-}
-
-func (h *transportEndpointHeap) Less(i, j int) bool {
- return (*h)[i].UniqueID() < (*h)[j].UniqueID()
-}
-
-func (h *transportEndpointHeap) Swap(i, j int) {
- (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-}
-
-func (h *transportEndpointHeap) Push(x interface{}) {
- *h = append(*h, x.(TransportEndpoint))
-}
+// checkEndpoint checks if an endpoint can be registered with the dispatcher.
+func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ for _, n := range netProtos {
+ if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
+ return err
+ }
+ }
-func (h *transportEndpointHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- old[n-1] = nil
- *h = old[:n-1]
- return x
+ return nil
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
@@ -334,9 +327,10 @@ type multiPortEndpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- endpoints transportEndpointHeap
- // reuse indicates if more than one endpoint is allowed.
- reuse bool
+ // endpoints stores the transport endpoints in the order in which they
+ // were bound. This is required for UDP SO_REUSEADDR.
+ endpoints []TransportEndpoint
+ flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@@ -362,6 +356,10 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[0]
}
+ if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent {
+ return mpep.endpoints[len(mpep.endpoints)-1]
+ }
+
payload := []byte{
byte(id.LocalPort),
byte(id.LocalPort >> 8),
@@ -379,7 +377,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
@@ -401,40 +399,63 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
- if !ep.reuse || !reusePort {
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
return tcpip.ErrPortInUse
}
}
- heap.Push(&ep.endpoints, t)
+ ep.endpoints = append(ep.endpoints, t)
+ ep.flags.AddRef(bits)
+
+ return nil
+}
+
+func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
+ if len(ep.endpoints) != 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ return tcpip.ErrPortInUse
+ }
+ }
return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
-func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
ep.mu.Lock()
defer ep.mu.Unlock()
for i, endpoint := range ep.endpoints {
if endpoint == t {
- heap.Remove(&ep.endpoints, i)
+ copy(ep.endpoints[i:], ep.endpoints[i+1:])
+ ep.endpoints[len(ep.endpoints)-1] = nil
+ ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
+
+ ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
break
}
}
return len(ep.endpoints) == 0
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
- // TODO(eyalsoha): Why?
- reusePort = false
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
@@ -454,15 +475,42 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.endpoints[id] = epsByNIC
}
- return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
+ return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
+}
+
+func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ return nil
+ }
+
+ return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep, bindToDevice)
+ eps.unregisterEndpoint(id, ep, flags, bindToDevice)
}
}
}
@@ -470,7 +518,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -520,7 +568,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -544,7 +592,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 2474a7db3..73dada928 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -127,7 +128,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -165,7 +166,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
NetworkHeader: buffer.View(ip),
TransportHeader: buffer.View(u),
@@ -195,7 +196,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
if !ok {
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
}
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want {
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a611e44ab..7e8b84867 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -83,12 +84,13 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, nil, tcpip.ErrNoRoute
}
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen)
+ hdr.Prepend(fakeTransHeaderLen)
v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
Data: buffer.View(v).ToVectorisedView(),
}); err != nil {
@@ -153,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -198,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false, /* reuse */
- 0, /* bindtoDevice */
+ ports.Flags{},
+ 0, /* bindtoDevice */
); err != nil {
return err
}
@@ -215,7 +217,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -232,7 +234,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -289,7 +291,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
return true
}
@@ -324,6 +326,17 @@ func (*fakeTransportProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeTransportProtocol) Wait() {}
+// Parse implements TransportProtocol.Parse.
+func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen)
+ if !ok {
+ return false
+ }
+ pkt.TransportHeader = hdr
+ pkt.Data.TrimFront(fakeTransHeaderLen)
+ return true
+}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
@@ -369,7 +382,7 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -380,7 +393,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -391,7 +404,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 1 {
@@ -446,7 +459,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -457,7 +470,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -468,7 +481,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 1 {
@@ -623,7 +636,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: req.ToVectorisedView(),
})
@@ -642,11 +655,10 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- hdrs := p.Pkt.Data.ToView()
- if dst := hdrs[0]; dst != 3 {
+ if dst := p.Pkt.NetworkHeader[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := hdrs[1]; src != 1 {
+ if src := p.Pkt.NetworkHeader[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}