diff options
Diffstat (limited to 'pkg/tcpip/stack')
31 files changed, 9348 insertions, 1029 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index f71073207..1c58bed2d 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -16,6 +16,18 @@ go_template_instance( ) go_template_instance( + name = "neighbor_entry_list", + out = "neighbor_entry_list.go", + package = "stack", + prefix = "neighborEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*neighborEntry", + "Linker": "*neighborEntry", + }, +) + +go_template_instance( name = "packet_buffer_list", out = "packet_buffer_list.go", package = "stack", @@ -27,6 +39,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ @@ -35,12 +59,18 @@ go_library( "forwarder.go", "icmp_rate_limit.go", "iptables.go", + "iptables_state.go", "iptables_targets.go", "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", "ndp.go", + "neighbor_cache.go", + "neighbor_entry.go", + "neighbor_entry_list.go", + "neighborstate_string.go", "nic.go", + "nud.go", "packet_buffer.go", "packet_buffer_list.go", "rand.go", @@ -48,7 +78,9 @@ go_library( "route.go", "stack.go", "stack_global_state.go", + "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], visibility = ["//visibility:public"], deps = [ @@ -74,10 +106,12 @@ go_test( size = "medium", srcs = [ "ndp_test.go", + "nud_test.go", "stack_test.go", "transport_demuxer_test.go", "transport_test.go", ], + shard_count = 20, deps = [ ":stack", "//pkg/rand", @@ -89,10 +123,12 @@ go_test( "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/ports", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) @@ -100,8 +136,11 @@ go_test( name = "stack_test", size = "small", srcs = [ + "fake_time_test.go", "forwarder_test.go", "linkaddrcache_test.go", + "neighbor_cache_test.go", + "neighbor_entry_test.go", "nic_test.go", ], library = ":stack", @@ -110,5 +149,9 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "@com_github_dpjacques_clockwork//:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 7d1ede1f2..470c265aa 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -20,332 +20,330 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" ) // Connection tracking is used to track and manipulate packets for NAT rules. -// The connection is created for a packet if it does not exist. Every connection -// contains two tuples (original and reply). The tuples are manipulated if there -// is a matching NAT rule. The packet is modified by looking at the tuples in the -// Prerouting and Output hooks. +// The connection is created for a packet if it does not exist. Every +// connection contains two tuples (original and reply). The tuples are +// manipulated if there is a matching NAT rule. The packet is modified by +// looking at the tuples in the Prerouting and Output hooks. +// +// Currently, only TCP tracking is supported. + +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 // Direction of the tuple. -type ctDirection int +type direction int const ( - dirOriginal ctDirection = iota + dirOriginal direction = iota dirReply ) -// Status of connection. -// TODO(gvisor.dev/issue/170): Add other states of connection. -type connStatus int - -const ( - connNew connStatus = iota - connEstablished -) - // Manipulation type for the connection. type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) -// connTrackMutable is the manipulatable part of the tuple. -type connTrackMutable struct { - // addr is source address of the tuple. - addr tcpip.Address - - // port is source port of the tuple. - port uint16 - - // protocol is network layer protocol. - protocol tcpip.NetworkProtocolNumber -} - -// connTrackImmutable is the non-manipulatable part of the tuple. -type connTrackImmutable struct { - // addr is destination address of the tuple. - addr tcpip.Address +// tuple holds a connection's identifying and manipulating data in one +// direction. It is immutable. +// +// +stateify savable +type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry - // direction is direction (original or reply) of the tuple. - direction ctDirection + tupleID - // port is destination port of the tuple. - port uint16 + // conn is the connection tracking entry this tuple belongs to. + conn *conn - // protocol is transport layer protocol. - protocol tcpip.TransportProtocolNumber + // direction is the direction of the tuple. + direction direction } -// connTrackTuple represents the tuple which is created from the -// packet. -type connTrackTuple struct { - // dst is non-manipulatable part of the tuple. - dst connTrackImmutable - - // src is manipulatable part of the tuple. - src connTrackMutable +// tupleID uniquely identifies a connection in one direction. It currently +// contains enough information to distinguish between any TCP or UDP +// connection, and will need to be extended to support other protocols. +// +// +stateify savable +type tupleID struct { + srcAddr tcpip.Address + srcPort uint16 + dstAddr tcpip.Address + dstPort uint16 + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber } -// connTrackTupleHolder is the container of tuple and connection. -type ConnTrackTupleHolder struct { - // conn is pointer to the connection tracking entry. - conn *connTrack - - // tuple is original or reply tuple. - tuple connTrackTuple +// reply creates the reply tupleID. +func (ti tupleID) reply() tupleID { + return tupleID{ + srcAddr: ti.dstAddr, + srcPort: ti.dstPort, + dstAddr: ti.srcAddr, + dstPort: ti.srcPort, + transProto: ti.transProto, + netProto: ti.netProto, + } } -// connTrack is the connection. -type connTrack struct { - // originalTupleHolder contains tuple in original direction. - originalTupleHolder ConnTrackTupleHolder - - // replyTupleHolder contains tuple in reply direction. - replyTupleHolder ConnTrackTupleHolder - - // status indicates connection is new or established. - status connStatus +// conn is a tracked connection. +// +// +stateify savable +type conn struct { + // original is the tuple in original direction. It is immutable. + original tuple - // timeout indicates the time connection should be active. - timeout time.Duration + // reply is the tuple in reply direction. It is immutable. + reply tuple - // manip indicates if the packet should be manipulated. + // manip indicates if the packet should be manipulated. It is immutable. manip manipType - // tcb is TCB control block. It is used to keep track of states - // of tcp connection. - tcb tcpconntrack.TCB - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. + // update the state of tcb. It is immutable. tcbHook Hook -} -// ConnTrackTable contains a map of all existing connections created for -// NAT rules. -type ConnTrackTable struct { - // connMu protects connTrackTable. - connMu sync.RWMutex + // mu protects all mutable state. + mu sync.Mutex `state:"nosave"` + // tcb is TCB control block. It is used to keep track of states + // of tcp connection and is protected by mu. + tcb tcpconntrack.TCB + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` +} - // connTrackTable maintains a map of tuples needed for connection tracking - // for iptables NAT rules. The key for the map is an integer calculated - // using seed, source address, destination address, source port and - // destination port. - CtMap map[uint32]ConnTrackTupleHolder +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. + return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout +} - // seed is a one-time random value initialized at stack startup - // and is used in calculation of hash key for connection tracking - // table. - Seed uint32 +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } } -// parseHeaders sets headers in the packet. -func parseHeaders(pkt *PacketBuffer) { - newPkt := pkt.Clone() +// ConnTrack tracks all connections created for NAT rules. Most users are +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable +type ConnTrack struct { + // seed is a one-time random value initialized at stack startup + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 - // Set network header. - hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - netHeader := header.IPv4(hdr) - newPkt.NetworkHeader = hdr - length := int(netHeader.HeaderLength()) + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - // Set transport header. - switch protocol := netHeader.TransportProtocol(); protocol { - case header.UDPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.UDP(h[length:])) - } - case header.TCPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.TCP(h[length:])) - } - } - pkt.NetworkHeader = newPkt.NetworkHeader - pkt.TransportHeader = newPkt.TransportHeader + // buckets is protected by mu. + buckets []bucket } -// packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { - var tuple connTrackTuple +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList +} - netHeader := header.IPv4(pkt.NetworkHeader) +// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid +// TCP header. +func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. + netHeader := header.IPv4(pkt.NetworkHeader) if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } tcpHeader := header.TCP(pkt.TransportHeader) if tcpHeader == nil { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } - tuple.src.addr = netHeader.SourceAddress() - tuple.src.port = tcpHeader.SourcePort() - tuple.src.protocol = header.IPv4ProtocolNumber - - tuple.dst.addr = netHeader.DestinationAddress() - tuple.dst.port = tcpHeader.DestinationPort() - tuple.dst.protocol = netHeader.TransportProtocol() - - return tuple, nil -} - -// getReplyTuple creates reply tuple for the given tuple. -func getReplyTuple(tuple connTrackTuple) connTrackTuple { - var replyTuple connTrackTuple - replyTuple.src.addr = tuple.dst.addr - replyTuple.src.port = tuple.dst.port - replyTuple.src.protocol = tuple.src.protocol - replyTuple.dst.addr = tuple.src.addr - replyTuple.dst.port = tuple.src.port - replyTuple.dst.protocol = tuple.dst.protocol - replyTuple.dst.direction = dirReply - - return replyTuple + return tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: tcpHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: tcpHeader.DestinationPort(), + transProto: netHeader.TransportProtocol(), + netProto: header.IPv4ProtocolNumber, + }, nil } -// makeNewConn creates new connection. -func makeNewConn(tuple, replyTuple connTrackTuple) connTrack { - var conn connTrack - conn.status = connNew - conn.originalTupleHolder.tuple = tuple - conn.originalTupleHolder.conn = &conn - conn.replyTupleHolder.tuple = replyTuple - conn.replyTupleHolder.conn = &conn - - return conn +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn } -// getTupleHash returns hash of the tuple. The fields used for -// generating hash are seed (generated once for stack), source address, -// destination address, source port and destination ports. -func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { - h := jenkins.Sum32(ct.Seed) - h.Write([]byte(tuple.src.addr)) - h.Write([]byte(tuple.dst.addr)) - portBuf := make([]byte, 2) - binary.LittleEndian.PutUint16(portBuf, tuple.src.port) - h.Write([]byte(portBuf)) - binary.LittleEndian.PutUint16(portBuf, tuple.dst.port) - h.Write([]byte(portBuf)) - - return h.Sum32() +// connFor gets the conn for pkt if it exists, or returns nil +// if it does not. It returns an error when pkt does not contain a valid TCP +// header. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support +// other transport protocols. +func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil, dirOriginal + } + return ct.connForTID(tid) } -// connTrackForPacket returns connTrack for packet. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other -// transport protocols. -func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - if hook == Prerouting { - // Headers will not be set in Prerouting. - // TODO(gvisor.dev/issue/170): Change this after parsing headers - // code is added. - parseHeaders(pkt) +func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } } - var dir ctDirection - tuple, err := packetToTuple(*pkt, hook) - if err != nil { - return nil, dir - } - - ct.connMu.Lock() - defer ct.connMu.Unlock() - - connTrackTable := ct.CtMap - hash := ct.getTupleHash(tuple) - - var conn *connTrack - switch createConn { - case true: - // If connection does not exist for the hash, create a new - // connection. - replyTuple := getReplyTuple(tuple) - replyHash := ct.getTupleHash(replyTuple) - newConn := makeNewConn(tuple, replyTuple) - conn = &newConn - - // Add tupleHolders to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked list. - ct.CtMap[hash] = conn.originalTupleHolder - ct.CtMap[replyHash] = conn.replyTupleHolder - default: - tupleHolder, ok := connTrackTable[hash] - if !ok { - return nil, dir - } + return nil, dirOriginal +} - // If this is the reply of new connection, set the connection - // status as ESTABLISHED. - conn = tupleHolder.conn - if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply { - conn.status = connEstablished - } - if tupleHolder.conn == nil { - panic("tupleHolder has null connection tracking entry") - } +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Prerouting && hook != Output { + return nil + } - dir = tupleHolder.tuple.dst.direction + // Create a new connection and change the port as per the iptables + // rule. This tuple will be used to manipulate the packet in + // handlePacket. + replyTID := tid.reply() + replyTID.srcAddr = rt.MinIP + replyTID.srcPort = rt.MinPort + var manip manipType + switch hook { + case Prerouting: + manip = manipDstPrerouting + case Output: + manip = manipDstOutput } - return conn, dir + conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn } -// SetNatInfo will manipulate the tuples according to iptables NAT rules. -func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) { - // Get the connection. Connection is always created before this - // function is called. - conn, _ := ct.connTrackForPacket(pkt, hook, false) - if conn == nil { - panic("connection should be created to manipulate tuples.") +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. + ct.buckets[tupleBucket].mu.Lock() } - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) - // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to - // support changing of address for Prerouting. - - // Change the port as per the iptables rule. This tuple will be used - // to manipulate the packet in HandlePacket. - conn.replyTupleHolder.tuple.src.addr = rt.MinIP - conn.replyTupleHolder.tuple.src.port = rt.MinPort - newHash := ct.getTupleHash(conn.replyTupleHolder.tuple) + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } - // Add the changed tuple to the map. - ct.connMu.Lock() - defer ct.connMu.Unlock() - ct.CtMap[newHash] = conn.replyTupleHolder - if hook == Output { - conn.replyTupleHolder.conn.manip = manipDstOutput + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) } - // Delete the old tuple. - delete(ct.CtMap, replyHash) + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.. -func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) { +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. +func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -354,21 +352,31 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) // modified. switch dir { case dirOriginal: - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) case dirReply: - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) } // handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) { +func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -377,13 +385,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, // modified. For prerouting redirection, we only reach this point // when replying, so packet sources are modified. if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) } else { - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } // Calculate the TCP checksum and set it. @@ -402,33 +410,32 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, netHeader.SetChecksum(^netHeader.CalculateChecksum()) } -// HandlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// handlePacket will manipulate the port and address of the packet if the +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false } - conn, dir := ct.connTrackForPacket(pkt, hook, false) - // Connection or Rule not found for the packet. - if conn == nil { - return + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return false } - netHeader := header.IPv4(pkt.NetworkHeader) - // TODO(gvisor.dev/issue/170): Need to support for other transport - // protocols as well. - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return + conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. + if conn == nil { + return true } tcpHeader := header.TCP(pkt.TransportHeader) if tcpHeader == nil { - return + return false } switch hook { @@ -442,39 +449,184 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle // other tcp states. - var st tcpconntrack.Result - if conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else { - switch hook { - case conn.tcbHook: - st = conn.tcb.UpdateStateOutbound(tcpHeader) - default: - st = conn.tcb.UpdateStateInbound(tcpHeader) - } + conn.mu.Lock() + defer conn.mu.Unlock() + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return } - // Delete conntrack if tcp connection is closed. - if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConnTrack(conn) + // We only track TCP connections. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return } -} -// deleteConnTrack deletes the connection. -func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) { - if conn == nil { + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { return } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + ct.insertConn(conn) +} + +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} + +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } + } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval + } + // We've hit the maximum interval. + return idx, maxInterval +} - tuple := conn.originalTupleHolder.tuple - hash := ct.getTupleHash(tuple) - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: ct.mu is locked for reading and bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false + } + + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } - ct.connMu.Lock() - defer ct.connMu.Unlock() + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) + } + ct.buckets[bucket].tuples.Remove(tuple) + + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() + } + + return true +} + +func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + // Lookup the connection. The reply's original destination + // describes the original address. + tid := tupleID{ + srcAddr: epID.LocalAddress, + srcPort: epID.LocalPort, + dstAddr: epID.RemoteAddress, + dstPort: epID.RemotePort, + transProto: header.TCPProtocolNumber, + netProto: header.IPv4ProtocolNumber, + } + conn, _ := ct.connForTID(tid) + if conn == nil { + // Not a tracked connection. + return "", 0, tcpip.ErrNotConnected + } else if conn.manip == manipNone { + // Unmanipulated connection. + return "", 0, tcpip.ErrInvalidOptionValue + } - delete(ct.CtMap, hash) - delete(ct.CtMap, replyHash) + return conn.original.dstAddr, conn.original.dstPort, nil } diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go new file mode 100644 index 000000000..92c8cb534 --- /dev/null +++ b/pkg/tcpip/stack/fake_time_test.go @@ -0,0 +1,209 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "container/heap" + "sync" + "time" + + "github.com/dpjacques/clockwork" + "gvisor.dev/gvisor/pkg/tcpip" +) + +type fakeClock struct { + clock clockwork.FakeClock + + // mu protects the fields below. + mu sync.RWMutex + + // times is min-heap of times. A heap is used for quick retrieval of the next + // upcoming time of scheduled work. + times *timeHeap + + // waitGroups stores one WaitGroup for all work scheduled to execute at the + // same time via AfterFunc. This allows parallel execution of all functions + // passed to AfterFunc scheduled for the same time. + waitGroups map[time.Time]*sync.WaitGroup +} + +func newFakeClock() *fakeClock { + return &fakeClock{ + clock: clockwork.NewFakeClock(), + times: &timeHeap{}, + waitGroups: make(map[time.Time]*sync.WaitGroup), + } +} + +var _ tcpip.Clock = (*fakeClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (fc *fakeClock) NowNanoseconds() int64 { + return fc.clock.Now().UnixNano() +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (fc *fakeClock) NowMonotonic() int64 { + return fc.NowNanoseconds() +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { + until := fc.clock.Now().Add(d) + wg := fc.addWait(until) + return &fakeTimer{ + clock: fc, + until: until, + timer: fc.clock.AfterFunc(d, func() { + defer wg.Done() + f() + }), + } +} + +// addWait adds an additional wait to the WaitGroup for parallel execution of +// all work scheduled for t. Returns a reference to the WaitGroup modified. +func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup { + fc.mu.RLock() + wg, ok := fc.waitGroups[t] + fc.mu.RUnlock() + + if ok { + wg.Add(1) + return wg + } + + fc.mu.Lock() + heap.Push(fc.times, t) + fc.mu.Unlock() + + wg = &sync.WaitGroup{} + wg.Add(1) + + fc.mu.Lock() + fc.waitGroups[t] = wg + fc.mu.Unlock() + + return wg +} + +// removeWait removes a wait from the WaitGroup for parallel execution of all +// work scheduled for t. +func (fc *fakeClock) removeWait(t time.Time) { + fc.mu.RLock() + defer fc.mu.RUnlock() + + wg := fc.waitGroups[t] + wg.Done() +} + +// advance executes all work that have been scheduled to execute within d from +// the current fake time. Blocks until all work has completed execution. +func (fc *fakeClock) advance(d time.Duration) { + // Block until all the work is done + until := fc.clock.Now().Add(d) + for { + fc.mu.Lock() + if fc.times.Len() == 0 { + fc.mu.Unlock() + return + } + + t := heap.Pop(fc.times).(time.Time) + if t.After(until) { + // No work to do + heap.Push(fc.times, t) + fc.mu.Unlock() + return + } + fc.mu.Unlock() + + diff := t.Sub(fc.clock.Now()) + fc.clock.Advance(diff) + + fc.mu.RLock() + wg := fc.waitGroups[t] + fc.mu.RUnlock() + + wg.Wait() + + fc.mu.Lock() + delete(fc.waitGroups, t) + fc.mu.Unlock() + } +} + +type fakeTimer struct { + clock *fakeClock + timer clockwork.Timer + + mu sync.RWMutex + until time.Time +} + +var _ tcpip.Timer = (*fakeTimer)(nil) + +// Reset implements tcpip.Timer.Reset. +func (ft *fakeTimer) Reset(d time.Duration) { + if !ft.timer.Reset(d) { + return + } + + ft.mu.Lock() + defer ft.mu.Unlock() + + ft.clock.removeWait(ft.until) + ft.until = ft.clock.clock.Now().Add(d) + ft.clock.addWait(ft.until) +} + +// Stop implements tcpip.Timer.Stop. +func (ft *fakeTimer) Stop() bool { + if !ft.timer.Stop() { + return false + } + + ft.mu.RLock() + defer ft.mu.RUnlock() + + ft.clock.removeWait(ft.until) + return true +} + +type timeHeap []time.Time + +var _ heap.Interface = (*timeHeap)(nil) + +func (h timeHeap) Len() int { + return len(h) +} + +func (h timeHeap) Less(i, j int) bool { + return h[i].Before(h[j]) +} + +func (h timeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *timeHeap) Push(x interface{}) { + *h = append(*h, x.(time.Time)) +} + +func (h *timeHeap) Pop() interface{} { + last := (*h)[len(*h)-1] + *h = (*h)[:len(*h)-1] + return last +} diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 6b64cd37f..3eff141e6 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go @@ -32,7 +32,7 @@ type pendingPacket struct { nic *NIC route *Route proto tcpip.NetworkProtocolNumber - pkt PacketBuffer + pkt *PacketBuffer } type forwardQueue struct { @@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue { return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { shouldWait := false f.Lock() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 8084d50bc..c962693f5 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -33,6 +34,10 @@ const ( // except where another value is explicitly used. It is chosen to match // the MTU of loopback interfaces on linux systems. fwdTestNetDefaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. @@ -68,16 +73,9 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { return &f.id } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { - // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fwdTestNetHeaderLen) - +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -96,13 +94,13 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + b[dstAddrOffset] = r.RemoteAddress[0] + b[srcAddrOffset] = f.id.LocalAddress[0] + b[protocolNumberOffset] = byte(params.Protocol) return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) } @@ -112,7 +110,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -123,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {} type fwdTestNetworkProtocol struct { addrCache *linkAddrCache addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) + onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) } +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) + func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } @@ -140,7 +140,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { } func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) +} + +func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = netHeader + pkt.Data.TrimFront(fwdTestNetHeaderLen) + return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -166,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {} func (f *fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { if f.addrCache != nil && f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, addr) + f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr) }) } return nil @@ -190,7 +200,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress LocalLinkAddress tcpip.LinkAddress - Pkt PacketBuffer + Pkt *PacketBuffer } type fwdTestLinkEndpoint struct { @@ -203,13 +213,13 @@ type fwdTestLinkEndpoint struct { } // InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -251,7 +261,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -270,7 +280,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, *pkt) + e.WritePacket(r, gso, protocol, pkt) n++ } @@ -280,7 +290,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: PacketBuffer{Data: vv}, + Pkt: &PacketBuffer{Data: vv}, } select { @@ -294,6 +304,16 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ @@ -361,8 +381,8 @@ func TestForwardingWithStaticResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -387,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any address will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -398,8 +418,8 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -429,8 +449,8 @@ func TestForwardingWithNoResolver(t *testing.T) { // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -445,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { @@ -459,16 +479,16 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. buf := buffer.NewView(30) - buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -480,9 +500,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -498,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -509,8 +528,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -524,9 +543,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -543,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -554,10 +572,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -571,14 +589,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) } - // The first 5 packets should not be forwarded so the the - // sequemnce number should start with 5. + seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). + if !ok { + t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) + } + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. want := uint16(i + 5) - if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { t.Fatalf("got the packet #%d, want = #%d", n, want) } @@ -596,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -609,8 +631,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Each packet has a different destination address (3 to // maxPendingResolutions + 7). buf := buffer.NewView(30) - buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -626,9 +648,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() - if b[0] < 8 { - t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 443423b3c..110ba073d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,40 +16,49 @@ package stack import ( "fmt" - "strings" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) -// Table names. +// tableID is an index into IPTables.tables. +type tableID int + const ( - TablenameNat = "nat" - TablenameMangle = "mangle" - TablenameFilter = "filter" + natID tableID = iota + mangleID + filterID + numTables ) -// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +// Table names. const ( - ChainNamePrerouting = "PREROUTING" - ChainNameInput = "INPUT" - ChainNameForward = "FORWARD" - ChainNameOutput = "OUTPUT" - ChainNamePostrouting = "POSTROUTING" + NATTable = "nat" + MangleTable = "mangle" + FilterTable = "filter" ) +// nameToID is immutable. +var nameToID = map[string]tableID{ + NATTable: natID, + MangleTable: mangleID, + FilterTable: filterID, +} + // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() IPTables { - // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for - // iotas. - return IPTables{ - Tables: map[string]Table{ - TablenameNat: Table{ +func DefaultTables() *IPTables { + return &IPTables{ + tables: [numTables]Table{ + natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -57,65 +66,71 @@ func DefaultTables() IPTables { Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - Underflows: map[Hook]int{ + Underflows: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - UserChains: map[string]int{}, }, - TablenameMangle: Table{ + mangleID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Output: 1, }, - Underflows: map[Hook]int{ - Prerouting: 0, - Output: 1, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, - TablenameFilter: Table{ + filterID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, }, - Priorities: map[Hook][]string{ - Input: []string{TablenameNat, TablenameFilter}, - Prerouting: []string{TablenameMangle, TablenameNat}, - Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + priorities: [NumHooks][]tableID{ + Prerouting: []tableID{mangleID, natID}, + Input: []tableID{natID, filterID}, + Output: []tableID{mangleID, natID, filterID}, }, - connections: ConnTrackTable{ - CtMap: make(map[uint32]ConnTrackTupleHolder), - Seed: generateRandUint32(), + connections: ConnTrack{ + seed: generateRandUint32(), }, + reaperDone: make(chan struct{}, 1), } } @@ -124,41 +139,61 @@ func DefaultTables() IPTables { func EmptyFilterTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, } } -// EmptyNatTable returns a Table with no rules and the filter table chains +// EmptyNATTable returns a Table with no rules and the filter table chains // mapped to HookUnset. -func EmptyNatTable() Table { +func EmptyNATTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + BuiltinChains: [NumHooks]int{ + Forward: HookUnset, }, - Underflows: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + Underflows: [NumHooks]int{ + Forward: HookUnset, }, - UserChains: map[string]int{}, } } +// GetTable returns a table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + id, ok := nameToID[name] + if !ok { + return Table{}, false + } + it.mu.RLock() + defer it.mu.RUnlock() + return it.tables[id], true +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { + id, ok := nameToID[name] + if !ok { + return tcpip.ErrInvalidOptionValue + } + it.mu.Lock() + defer it.mu.Unlock() + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } + it.modified = true + it.tables[id] = table + return nil +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -180,13 +215,27 @@ const ( // // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { + // Many users never configure iptables. Spare them the cost of rule + // traversal if rules have never been set. + it.mu.RLock() + defer it.mu.RUnlock() + if !it.modified { + return true + } + // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.HandlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] + priorities := it.priorities[hook] + for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } + table := it.tables[tableID] ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. @@ -215,17 +264,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -249,9 +340,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * return drop, natPkts } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -296,25 +387,14 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. - if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } - } - // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -322,7 +402,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, *pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { return RuleDrop, 0 } @@ -336,46 +416,8 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } -func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. - // Check the transport protocol. - if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { - return false - } - - // Check the destination IP. - dest := hdr.DestinationAddress() - matches := true - for i := range filter.Dst { - if dest[i]&filter.DstMask[i] != filter.Dst[i] { - matches = false - break - } - } - if matches == filter.DstInvert { - return false - } - - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(filter.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which begins - // with the name should be matched. - ifName := filter.OutputInterface - matches = true - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } else { - matches = nicName == ifName - } - return filter.OutputInterfaceInvert != matches - } - - return true +// OriginalDst returns the original destination of redirected connections. It +// returns an error if the connection doesn't exist or isn't redirected. +func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + return it.connections.originalDst(epID) } diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 36cc6275d..dc88033c7 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -24,7 +24,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -32,7 +32,7 @@ func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, t type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -41,7 +41,7 @@ func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcp type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -52,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -61,7 +61,7 @@ func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -92,17 +92,12 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 } - // Set network header. - if hook == Prerouting { - parseHeaders(pkt) - } - // Drop the packet if network and transport header are not set. if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { return RuleDrop, 0 @@ -155,12 +150,11 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 } - // Set up conection for matching NAT rule. - // Only the first packet of the connection comes here. - // Other packets will be manipulated in connection tracking. - if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil { - ct.SetNatInfo(pkt, rt, hook) - ct.HandlePacket(pkt, hook, gso, r) + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { + ct.handlePacket(pkt, hook, gso, r) } default: return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index fe06007ae..73274ada9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,7 +15,11 @@ package stack import ( + "strings" + "sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // A Hook specifies one of the hooks built into the network stack. @@ -74,63 +78,65 @@ const ( ) // IPTables holds all the tables for a netstack. +// +// +stateify savable type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table + // mu protects tables, priorities, and modified. + mu sync.RWMutex - // Priorities maps each hook to a list of table names. The order of the + // tables maps tableIDs to tables. Holds builtin tables only, not user + // tables. mu must be locked for accessing. + tables [numTables]Table + + // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string + // hook. mu needs to be locked for accessing. + priorities [NumHooks][]tableID + + // modified is whether tables have been modified at least once. It is + // used to elide the iptables performance overhead for workloads that + // don't utilize iptables. + modified bool - connections ConnTrackTable + connections ConnTrack + + // reaperDone can be signalled to stop the reaper goroutine. + reaperDone chan struct{} } // A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules with some metadata for entrypoints and such. +// really just a list of rules. +// +// +stateify savable type Table struct { // Rules holds the rules that make up the table. Rules []Rule // BuiltinChains maps builtin chains to their entrypoint rule in Rules. - BuiltinChains map[Hook]int + BuiltinChains [NumHooks]int // Underflows maps builtin chains to their underflow rule in Rules // (i.e. the rule to execute if the chain returns without a verdict). - Underflows map[Hook]int - - // UserChains holds user-defined chains for the keyed by name. Users - // can give their chains arbitrary names. - UserChains map[string]int - - // Metadata holds information about the Table that is useful to users - // of IPTables, but not to the netstack IPTables code itself. - metadata interface{} + Underflows [NumHooks]int } // ValidHooks returns a bitmap of the builtin hooks for the given table. func (table *Table) ValidHooks() uint32 { hooks := uint32(0) - for hook := range table.BuiltinChains { - hooks |= 1 << hook + for hook, ruleIdx := range table.BuiltinChains { + if ruleIdx != HookUnset { + hooks |= 1 << hook + } } return hooks } -// Metadata returns the metadata object stored in table. -func (table *Table) Metadata() interface{} { - return table.metadata -} - -// SetMetadata sets the metadata object stored in table. -func (table *Table) SetMetadata(metadata interface{}) { - table.metadata = metadata -} - // A Rule is a packet processing rule. It consists of two pieces. First it // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. +// +// +stateify savable type Rule struct { // Filter holds basic IP filtering fields common to every rule. Filter IPHeaderFilter @@ -143,6 +149,8 @@ type Rule struct { } // IPHeaderFilter holds basic IP filtering data common to every rule. +// +// +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber @@ -159,6 +167,16 @@ type IPHeaderFilter struct { // comparison. DstInvert bool + // Src matches the source IP address. + Src tcpip.Address + + // SrcMask masks bits of the source IP address when comparing with Src. + SrcMask tcpip.Address + + // SrcInvert inverts the meaning of the source IP check, i.e. when true the + // filter will match packets that fail the source comparison. + SrcInvert bool + // OutputInterface matches the name of the outgoing interface for the // packet. OutputInterface string @@ -173,6 +191,55 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } +// match returns whether hdr matches the filter. +func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the source and destination IPs. + if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + return false + } + + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(fl.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := fl.OutputInterface + matches := true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return fl.OutputInterfaceInvert != matches + } + + return true +} + +// filterAddress returns whether addr matches the filter. +func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { + matches := true + for i := range filterAddr { + if addr[i]&mask[i] != filterAddr[i] { + matches = false + break + } + } + return matches != invert +} + // A Matcher is the interface for matching packets. type Matcher interface { // Name returns the name of the Matcher. @@ -183,7 +250,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. @@ -191,5 +258,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) + Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 403557fd7..6f73a0ce4 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 1baa498d0..b15b8d1cb 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) if f := r.onLinkAddressRequest; f != nil { f() diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 526c7d6ff..5174e639c 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -33,12 +33,6 @@ const ( // Default = 1 (from RFC 4862 section 5.1) defaultDupAddrDetectTransmits = 1 - // defaultRetransmitTimer is the default amount of time to wait between - // sending NDP Neighbor solicitation messages. - // - // Default = 1s (from RFC 4861 section 10). - defaultRetransmitTimer = time.Second - // defaultMaxRtrSolicitations is the default number of Router // Solicitation messages to send when a NIC becomes enabled. // @@ -79,16 +73,6 @@ const ( // Default = true. defaultAutoGenGlobalAddresses = true - // minimumRetransmitTimer is the minimum amount of time to wait between - // sending NDP Neighbor solicitation messages. Note, RFC 4861 does - // not impose a minimum Retransmit Timer, but we do here to make sure - // the messages are not sent all at once. We also come to this value - // because in the RetransmitTimer field of a Router Advertisement, a - // value of 0 means unspecified, so the smallest valid value is 1. - // Note, the unit of the RetransmitTimer field in the Router - // Advertisement is milliseconds. - minimumRetransmitTimer = time.Millisecond - // minimumRtrSolicitationInterval is the minimum amount of time to wait // between sending Router Solicitation messages. This limit is imposed // to make sure that Router Solicitation messages are not sent all at @@ -467,8 +451,17 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - // The timer used to send the next router solicitation message. - rtrSolicitTimer *time.Timer + rtrSolicit struct { + // The timer used to send the next router solicitation message. + timer tcpip.Timer + + // Used to let the Router Solicitation timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the NIC this ndpState is associated with. MUST be set when the timer is + // set. + done *bool + } // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -494,7 +487,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer *time.Timer + timer tcpip.Timer // Used to let the DAD timer know that it has been stopped. // @@ -506,38 +499,38 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - // Timer to invalidate the default router. + // Job to invalidate the default router. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - // Timer to invalidate the on-link prefix. + // Job to invalidate the on-link prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // tempSLAACAddrState holds state associated with a temporary SLAAC address. type tempSLAACAddrState struct { - // Timer to deprecate the temporary SLAAC address. + // Job to deprecate the temporary SLAAC address. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the temporary SLAAC address. + // Job to invalidate the temporary SLAAC address. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job - // Timer to regenerate the temporary SLAAC address. + // Job to regenerate the temporary SLAAC address. // // Must not be nil. - regenTimer *tcpip.CancellableTimer + regenJob *tcpip.Job createdAt time.Time @@ -552,15 +545,15 @@ type tempSLAACAddrState struct { // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { - // Timer to deprecate the prefix. + // Job to deprecate the prefix. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the prefix. + // Job to invalidate the prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job // Nonzero only when the address is not valid forever. validUntil time.Time @@ -642,19 +635,20 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } var done bool - var timer *time.Timer + var timer tcpip.Timer // We initially start a timer to fire immediately because some of the DAD work // cannot be done while holding the NIC's lock. This is effectively the same // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. - timer = time.AfterFunc(0, func() { - ndp.nic.mu.RLock() + timer = ndp.nic.stack.Clock().AfterFunc(0, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + if done { // If we reach this point, it means that the DAD timer fired after // another goroutine already obtained the NIC lock and stopped DAD // before this function obtained the NIC lock. Simply return here and do // nothing further. - ndp.nic.mu.RUnlock() return } @@ -665,15 +659,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } dadDone := remaining == 0 - ndp.nic.mu.RUnlock() var err *tcpip.Error if !dadDone { - err = ndp.sendDADPacket(addr) + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + + // Do not hold the lock when sending packets which may be a long running + // task or may block link address resolution. We know this is safe + // because immediately after obtaining the lock again, we check if DAD + // has been stopped before doing any work with the NIC. Note, DAD would be + // stopped if the NIC was disabled or removed, or if the address was + // removed. + ndp.nic.mu.Unlock() + err = ndp.sendDADPacket(addr, ref) + ndp.nic.mu.Lock() } - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() if done { // If we reach this point, it means that DAD was stopped after we released // the NIC's read lock and before we obtained the write lock. @@ -721,17 +723,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // addr. // // addr must be a tentative IPv6 address on ndp's NIC. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { +// +// The NIC ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { + return err + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) @@ -750,7 +759,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -846,9 +855,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update - // the invalidation timer. - rtr.invalidationTimer.StopLocked() - rtr.invalidationTimer.Reset(rl) + // the invalidation job. + rtr.invalidationJob.Cancel() + rtr.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = rtr case ok && rl == 0: @@ -925,7 +934,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.StopLocked() + rtr.invalidationJob.Cancel() delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. @@ -954,12 +963,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } state := defaultRouterState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), } - state.invalidationTimer.Reset(rl) + state.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = state } @@ -984,13 +993,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } state := onLinkPrefixState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } if l < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(l) + state.invalidationJob.Schedule(l) } ndp.onLinkPrefixes[prefix] = state @@ -1008,7 +1017,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - s.invalidationTimer.StopLocked() + s.invalidationJob.Cancel() delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. @@ -1057,14 +1066,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. // - // Update the invalidation timer. + // Update the invalidation job. - prefixState.invalidationTimer.StopLocked() + prefixState.invalidationJob.Cancel() if vl < header.NDPInfiniteLifetime { - // Prefix is valid for a finite lifetime, reset the timer to expire after + // Prefix is valid for a finite lifetime, schedule the job to execute after // the new valid lifetime. - prefixState.invalidationTimer.Reset(vl) + prefixState.invalidationJob.Schedule(vl) } ndp.onLinkPrefixes[prefix] = prefixState @@ -1129,7 +1138,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) @@ -1137,7 +1146,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.deprecateSLAACAddress(state.stableAddr.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1159,19 +1168,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if !ndp.generateSLAACAddr(prefix, &state) { // We were unable to generate an address for the prefix, we do not nothing - // further as there is no reason to maintain state or timers for a prefix we + // further as there is no reason to maintain state or jobs for a prefix we // do not have an address for. return } - // Setup the initial timers to deprecate and invalidate prefix. + // Setup the initial jobs to deprecate and invalidate prefix. if pl < header.NDPInfiniteLifetime && pl != 0 { - state.deprecationTimer.Reset(pl) + state.deprecationJob.Schedule(pl) } if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) + state.invalidationJob.Schedule(vl) state.validUntil = now.Add(vl) } @@ -1403,7 +1412,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } state := tempSLAACAddrState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1416,7 +1425,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.deprecateSLAACAddress(tempAddrState.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1429,7 +1438,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1456,9 +1465,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ref: ref, } - state.deprecationTimer.Reset(pl) - state.invalidationTimer.Reset(vl) - state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + state.deprecationJob.Schedule(pl) + state.invalidationJob.Schedule(vl) + state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) prefixState.generationAttempts++ prefixState.tempAddrs[generatedAddr.Address] = state @@ -1493,16 +1502,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat prefixState.stableAddr.ref.deprecated = false } - // If prefix was preferred for some finite lifetime before, stop the - // deprecation timer so it can be reset. - prefixState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, cancel the + // deprecation job so it can be reset. + prefixState.deprecationJob.Cancel() now := time.Now() - // Reset the deprecation timer if prefix has a finite preferred lifetime. + // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { if !deprecated { - prefixState.deprecationTimer.Reset(pl) + prefixState.deprecationJob.Schedule(pl) } prefixState.preferredUntil = now.Add(pl) } else { @@ -1521,9 +1530,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. if vl >= header.NDPInfiniteLifetime { - // Handle the infinite valid lifetime separately as we do not keep a timer - // in this case. - prefixState.invalidationTimer.StopLocked() + // Handle the infinite valid lifetime separately as we do not schedule a + // job in this case. + prefixState.invalidationJob.Cancel() prefixState.validUntil = time.Time{} } else { var effectiveVl time.Duration @@ -1544,8 +1553,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } if effectiveVl != 0 { - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.invalidationJob.Cancel() + prefixState.invalidationJob.Schedule(effectiveVl) prefixState.validUntil = now.Add(effectiveVl) } } @@ -1557,7 +1566,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // Note, we do not need to update the entries in the temporary address map - // after updating the timers because the timers are held as pointers. + // after updating the jobs because the jobs are held as pointers. var regenForAddr tcpip.Address allAddressesRegenerated := true for tempAddr, tempAddrState := range prefixState.tempAddrs { @@ -1571,14 +1580,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer valid, invalidate it immediately. Otherwise, - // reset the invalidation timer. + // reset the invalidation job. newValidLifetime := validUntil.Sub(now) if newValidLifetime <= 0 { ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) continue } - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.invalidationTimer.Reset(newValidLifetime) + tempAddrState.invalidationJob.Cancel() + tempAddrState.invalidationJob.Schedule(newValidLifetime) // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary // address is the lower of the preferred lifetime of the stable address or @@ -1591,17 +1600,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer preferred, deprecate it immediately. - // Otherwise, reset the deprecation timer. + // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) - tempAddrState.deprecationTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.ref) } else { tempAddrState.ref.deprecated = false - tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } - tempAddrState.regenTimer.StopLocked() + tempAddrState.regenJob.Cancel() if tempAddrState.regenerated { } else { allAddressesRegenerated = false @@ -1612,7 +1621,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // immediately after we finish iterating over the temporary addresses. regenForAddr = tempAddr } else { - tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) } } } @@ -1692,7 +1701,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. // // Panics if the SLAAC prefix is not known. // @@ -1704,8 +1713,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa } state.stableAddr.ref = nil - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() + state.deprecationJob.Cancel() + state.invalidationJob.Cancel() delete(ndp.slaacPrefixes, prefix) } @@ -1750,13 +1759,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi } // cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's -// timers and entry. +// jobs and entry. // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { - tempAddrState.deprecationTimer.StopLocked() - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.regenTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() + tempAddrState.invalidationJob.Cancel() + tempAddrState.regenJob.Cancel() delete(tempAddrs, tempAddr) } @@ -1816,7 +1825,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicitTimer != nil { + if ndp.rtrSolicit.timer != nil { // We are already soliciting routers. return } @@ -1833,14 +1842,27 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { + var done bool + ndp.rtrSolicit.done = &done + ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that the RS timer fired after another + // goroutine already obtained the NIC lock and stopped solicitations. + // Simply return here and do nothing further. + ndp.nic.mu.Unlock() + return + } + // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) if ref == nil { - ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) } + ndp.nic.mu.Unlock() + localAddr := ref.ep.ID().LocalAddress r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() @@ -1849,6 +1871,13 @@ func (ndp *ndpState) startSolicitingRouters() { // header.IPv6AllRoutersMulticastAddress is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { + return + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) @@ -1881,7 +1910,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) @@ -1893,17 +1922,18 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - if remaining == 0 { - ndp.rtrSolicitTimer = nil - } else if ndp.rtrSolicitTimer != nil { + if done || remaining == 0 { + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil + } else if ndp.rtrSolicit.timer != nil { // Note, we need to explicitly check to make sure that // the timer field is not nil because if it was nil but // we still reached this point, then we know the NIC // was requested to stop soliciting routers so we don't // need to send the next Router Solicitation message. - ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval) + ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) } + ndp.nic.mu.Unlock() }) } @@ -1913,13 +1943,15 @@ func (ndp *ndpState) startSolicitingRouters() { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicitTimer == nil { + if ndp.rtrSolicit.timer == nil { // Nothing to do. return } - ndp.rtrSolicitTimer.Stop() - ndp.rtrSolicitTimer = nil + *ndp.rtrSolicit.done = true + ndp.rtrSolicit.timer.Stop() + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil } // initializeTempAddrState initializes state related to temporary SLAAC diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b3d174cdd..644ba7c33 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -36,15 +36,24 @@ import ( ) const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") - defaultTimeout = 100 * time.Millisecond - defaultAsyncEventTimeout = time.Second + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") + + // Extra time to use when waiting for an async event to occur. + defaultAsyncPositiveEventTimeout = 10 * time.Second + + // Extra time to use when waiting for an async event to not occur. + // + // Since a negative check is used to make sure an event did not happen, it is + // okay to use a smaller timeout compared to the positive case since execution + // stall in regards to the monotonic clock will not affect the expected + // outcome. + defaultAsyncNegativeEventTimeout = time.Second ) var ( @@ -421,45 +430,90 @@ func TestDADResolve(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + // We add a default route so the call to FindRoute below will succeed + // once we have an assigned address. + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: addr3, + NIC: nicID, + }}) + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Make sure the address does not resolve before the resolution time has // passed. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout) - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + // Should not get a route even if we specify the local address as the + // tentative address. + { + r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) + if err != tcpip.ErrNoRoute { + t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + } + r.Release() } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + { + r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) + if err != tcpip.ErrNoRoute { + t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + } + r.Release() + } + + if t.Failed() { + t.FailNow() } // Wait for DAD to resolve. select { - case <-time.After(2 * defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if addr.Address != addr1 { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + // Should get a route using the address now that it is resolved. + { + r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) + if err != nil { + t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err) + } else if r.LocalAddress != addr1 { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) + } + r.Release() + } + { + r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) + if err != nil { + t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err) + } else if r.LocalAddress != addr1 { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) + } + r.Release() + } + + if t.Failed() { + t.FailNow() } // Should not have sent any more NS messages. @@ -613,7 +667,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -935,7 +989,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -970,14 +1024,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -986,7 +1040,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -994,7 +1048,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -1003,7 +1057,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. @@ -1124,7 +1178,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -1200,14 +1254,14 @@ func TestRouterDiscovery(t *testing.T) { default: } - // Wait for lladdr2's router invalidation timer to fire. The lifetime + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) @@ -1217,14 +1271,14 @@ func TestRouterDiscovery(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) expectRouterEvent(llAddr2, false) - // Wait for lladdr3's router invalidation timer to fire. The lifetime + // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) } // TestRouterDiscoveryMaxRouters tests that only @@ -1373,7 +1427,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -1448,14 +1502,14 @@ func TestPrefixDiscovery(t *testing.T) { default: } - // Wait for prefix2's most recent invalidation timer plus some buffer to + // Wait for prefix2's most recent invalidation job plus some buffer to // expire. select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout): + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for prefix discovery event") } @@ -1520,7 +1574,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): + case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): } // Receive an RA with finite lifetime. @@ -1545,7 +1599,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): + case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): } // Receive an RA with a prefix with a lifetime value greater than the @@ -1554,7 +1608,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout): + case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): } // Receive an RA with 0 lifetime. @@ -1790,7 +1844,7 @@ func TestAutoGenAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -1917,7 +1971,7 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -1930,7 +1984,7 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -2036,10 +2090,10 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { @@ -2135,7 +2189,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } @@ -2143,7 +2197,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -2220,7 +2274,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } select { @@ -2228,7 +2282,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -2318,13 +2372,13 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { t.Fatal(mismatch) } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { t.Fatal(mismatch) } @@ -2341,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { for _, addr := range tempAddrs { // Wait for a deprecation then invalidation event, or just an invalidation // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation timers could fire in any + // cases because the deprecation and invalidation jobs could execute in any // order. select { case e := <-ndpDisp.autoGenAddrC: @@ -2353,7 +2407,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { @@ -2362,12 +2416,12 @@ func TestAutoGenTempAddrRegen(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } } else { t.Fatalf("got unexpected auto-generated event = %+v", e) } - case <-time.After(invalidateAfter + defaultAsyncEventTimeout): + case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } @@ -2378,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } } -// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's -// regeneration timer gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { +// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's +// regeneration job gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( nicID = 1 regenAfter = 2 * time.Second @@ -2472,14 +2526,14 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncEventTimeout): + case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): } // Prefer the prefix again. // // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a timer. + // - this regeneration does not depend on a job. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(tempAddr2, newAddr) @@ -2501,24 +2555,24 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncEventTimeout): + case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): } // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration timer gets reset. + // RA, the regeneration job gets scheduled again. // // The maximum lifetime is the sum of the minimum lifetimes for temporary // addresses + the time that has already passed since the last address was - // generated so that the regeneration timer is needed to generate the next + // generated so that the regeneration job is needed to generate the next // address. - newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout + newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } // TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response @@ -2666,7 +2720,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -2679,7 +2733,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -2939,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr(addr2) } -// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated // when its preferred lifetime expires. -func TestAutoGenAddrTimerDeprecation(t *testing.T) { +func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 const newMinVL = 2 newMinVLDuration := newMinVL * time.Second @@ -3025,7 +3079,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3065,7 +3119,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3079,7 +3133,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3111,7 +3165,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { @@ -3120,12 +3174,12 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3250,7 +3304,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout): + case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -3394,7 +3448,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): } // Wait for the invalidation event. @@ -3403,7 +3457,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(2 * defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -3459,12 +3513,12 @@ func TestAutoGenAddrRemoval(t *testing.T) { } expectAutoGenAddrEvent(addr, invalidatedAddr) - // Wait for the original valid lifetime to make sure the original timer - // got stopped/cleaned up. + // Wait for the original valid lifetime to make sure the original job got + // cancelled/cleaned up. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -3627,7 +3681,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -3725,7 +3779,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout): + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3792,7 +3846,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -3818,7 +3872,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -3985,7 +4039,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -4104,7 +4158,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -4206,7 +4260,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } @@ -4232,7 +4286,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") } } else { @@ -4240,7 +4294,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for auto gen addr event") } } @@ -4824,7 +4878,7 @@ func TestCleanupNDPState(t *testing.T) { // Should not get any more events (invalidation timers should have been // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) select { case <-ndpDisp.routerC: t.Error("unexpected router event") @@ -5127,24 +5181,24 @@ func TestRouterSolicitation(t *testing.T) { // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) + waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) remaining-- } for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout) - waitForPkt(2 * defaultAsyncEventTimeout) + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) + waitForPkt(defaultAsyncPositiveEventTimeout) } else { - waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout) + waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) } } // Make sure no more RS. if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout) + waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) + waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) } // Make sure the counter got properly @@ -5260,11 +5314,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { // A single RS may have been sent before solicitations were stopped. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok = e.ReadContext(ctx); ok { t.Fatal("should not have sent more than one RS message") @@ -5274,7 +5328,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") @@ -5287,10 +5341,10 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Start soliciting routers. test.startFn(t, s) - waitForPkt(delay + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) + waitForPkt(delay + defaultAsyncPositiveEventTimeout) + waitForPkt(interval + defaultAsyncPositiveEventTimeout) + waitForPkt(interval + defaultAsyncPositiveEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") @@ -5299,7 +5353,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go new file mode 100644 index 000000000..1d37716c2 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -0,0 +1,335 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const neighborCacheSize = 512 // max entries per interface + +// neighborCache maps IP addresses to link addresses. It uses the Least +// Recently Used (LRU) eviction strategy to implement a bounded cache for +// dynmically acquired entries. It contains the state machine and configuration +// for running Neighbor Unreachability Detection (NUD). +// +// There are two types of entries in the neighbor cache: +// 1. Dynamic entries are discovered automatically by neighbor discovery +// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm +// reachability with the device once the entry's state becomes Stale. +// 2. Static entries are explicitly added by a user and have no expiration. +// Their state is always Static. The amount of static entries stored in the +// cache is unbounded. +// +// neighborCache implements NUDHandler. +type neighborCache struct { + nic *NIC + state *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + cache map[tcpip.Address]*neighborEntry + dynamic struct { + lru neighborEntryList + + // count tracks the amount of dynamic entries in the cache. This is + // needed since static entries do not count towards the LRU cache + // eviction strategy. + count uint16 + } +} + +var _ NUDHandler = (*neighborCache)(nil) + +// getOrCreateEntry retrieves a cache entry associated with addr. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[remoteAddr]; ok { + entry.mu.RLock() + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.lru.PushFront(entry) + } + entry.mu.RUnlock() + return entry + } + + // The entry that needs to be created must be dynamic since all static + // entries are directly added to the cache via addStaticEntry. + entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + if n.dynamic.count == neighborCacheSize { + e := n.dynamic.lru.Back() + e.mu.Lock() + + delete(n.cache, e.neigh.Addr) + n.dynamic.lru.Remove(e) + n.dynamic.count-- + + e.dispatchRemoveEventLocked() + e.setStateLocked(Unknown) + e.notifyWakersLocked() + e.mu.Unlock() + } + n.cache[remoteAddr] = entry + n.dynamic.lru.PushFront(entry) + n.dynamic.count++ + return entry +} + +// entry looks up the neighbor cache for translating address to link address +// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there +// is a LinkAddressResolver registered with the network protocol, the cache +// attempts to resolve the address and returns ErrWouldBlock. If a Waker is +// provided, it will be notified when address resolution is complete (success +// or not). +// +// If address resolution is required, ErrNoLinkAddress and a notification +// channel is returned for the top level caller to block. Channel is closed +// once address resolution is complete (success or not). +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + if linkRes != nil { + if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { + e := NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + } + return e, nil, nil + } + } + + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry.mu.Lock() + defer entry.mu.Unlock() + + switch s := entry.neigh.State; s { + case Reachable, Static: + return entry.neigh, nil, nil + + case Unknown, Incomplete, Stale, Delay, Probe: + entry.addWakerLocked(w) + + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.neigh, nil, tcpip.ErrNoLinkAddress + } + entry.done = make(chan struct{}) + } + + entry.handlePacketQueuedLocked() + return entry.neigh, entry.done, tcpip.ErrWouldBlock + + case Failed: + return entry.neigh, nil, tcpip.ErrNoLinkAddress + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", s)) + } +} + +// removeWaker removes a waker that has been added when link resolution for +// addr was requested. +func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { + n.mu.Lock() + if entry, ok := n.cache[addr]; ok { + delete(entry.wakers, waker) + } + n.mu.Unlock() +} + +// entries returns all entries in the neighbor cache. +func (n *neighborCache) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(n.cache)) + n.mu.RLock() + for _, entry := range n.cache { + entry.mu.RLock() + entries = append(entries, entry.neigh) + entry.mu.RUnlock() + } + n.mu.RUnlock() + return entries +} + +// addStaticEntry adds a static entry to the neighbor cache, mapping an IP +// address to a link address. If a dynamic entry exists in the neighbor cache +// with the same address, it will be replaced with this static entry. If a +// static entry exists with the same address but different link address, it +// will be updated with the new link address. If a static entry exists with the +// same address and link address, nothing will happen. +func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[addr]; ok { + entry.mu.Lock() + if entry.neigh.State != Static { + // Dynamic entry found with the same address. + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } else if entry.neigh.LinkAddr == linkAddr { + // Static entry found with the same address and link address. + entry.mu.Unlock() + return + } else { + // Static entry found with the same address but different link address. + entry.neigh.LinkAddr = linkAddr + entry.dispatchChangeEventLocked(entry.neigh.State) + entry.mu.Unlock() + return + } + + // Notify that resolution has been interrupted, just in case the entry was + // in the Incomplete or Probe state. + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = entry +} + +// removeEntryLocked removes the specified entry from the neighbor cache. +func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + if entry.neigh.State != Failed { + entry.dispatchRemoveEventLocked() + } + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + + delete(n.cache, entry.neigh.Addr) +} + +// removeEntry removes a dynamic or static entry by address from the neighbor +// cache. Returns true if the entry was found and deleted. +func (n *neighborCache) removeEntry(addr tcpip.Address) bool { + n.mu.Lock() + defer n.mu.Unlock() + + entry, ok := n.cache[addr] + if !ok { + return false + } + + entry.mu.Lock() + defer entry.mu.Unlock() + + n.removeEntryLocked(entry) + return true +} + +// clear removes all dynamic and static entries from the neighbor cache. +func (n *neighborCache) clear() { + n.mu.Lock() + defer n.mu.Unlock() + + for _, entry := range n.cache { + entry.mu.Lock() + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + n.dynamic.lru = neighborEntryList{} + n.cache = make(map[tcpip.Address]*neighborEntry) + n.dynamic.count = 0 +} + +// config returns the NUD configuration. +func (n *neighborCache) config() NUDConfigurations { + return n.state.Config() +} + +// setConfig changes the NUD configuration. +// +// If config contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *neighborCache) setConfig(config NUDConfigurations) { + config.resetInvalidFields() + n.state.SetConfig(config) +} + +// HandleProbe implements NUDHandler.HandleProbe by following the logic defined +// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled +// by the caller. +func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) { + entry := n.getOrCreateEntry(remoteAddr, localAddr, nil) + entry.mu.Lock() + entry.handleProbeLocked(remoteLinkAddr) + entry.mu.Unlock() +} + +// HandleConfirmation implements NUDHandler.HandleConfirmation by following the +// logic defined in RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce cryptographically +// generated addresses, as defined in RFC 3972, Cryptographically Generated +// Addresses (CGA). This ensures that the claimed source of an NDP message is +// the owner of the claimed address. +func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleConfirmationLocked(linkAddr, flags) + entry.mu.Unlock() + } + // The confirmation SHOULD be silently discarded if the recipient did not + // initiate any communication with the target. This is indicated if there is + // no matching entry for the remote address. +} + +// HandleUpperLevelConfirmation implements +// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in +// RFC 4861 section 7.3.1. +func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleUpperLevelConfirmationLocked() + entry.mu.Unlock() + } +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go new file mode 100644 index 000000000..4cb2c9c6b --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -0,0 +1,1752 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "math/rand" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // entryStoreSize is the default number of entries that will be generated and + // added to the entry store. This number needs to be larger than the size of + // the neighbor cache to give ample opportunity for verifying behavior during + // cache overflows. Four times the size of the neighbor cache allows for + // three complete cache overflows. + entryStoreSize = 4 * neighborCacheSize + + // typicalLatency is the typical latency for an ARP or NDP packet to travel + // to a router and back. + typicalLatency = time.Millisecond + + // testEntryBroadcastAddr is a special address that indicates a packet should + // be sent to all nodes. + testEntryBroadcastAddr = tcpip.Address("broadcast") + + // testEntryLocalAddr is the source address of neighbor probes. + testEntryLocalAddr = tcpip.Address("local_addr") + + // testEntryBroadcastLinkAddr is a special link address sent back to + // multicast neighbor probes. + testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") + + // infiniteDuration indicates that a task will not occur in our lifetime. + infiniteDuration = time.Duration(math.MaxInt64) +) + +// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor +// entries. The UpdatedAt field is ignored due to a lack of a deterministic +// method to predict the time that an event will be dispatched. +func entryDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + } +} + +// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to +// sort slices of entries for cases where ordering must be ignored. +func entryDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { + config.resetInvalidFields() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return &neighborCache{ + nic: &NIC{ + stack: &Stack{ + clock: clock, + nudDisp: nudDisp, + }, + id: 1, + }, + state: NewNUDState(config, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } +} + +// testEntryStore contains a set of IP to NeighborEntry mappings. +type testEntryStore struct { + mu sync.RWMutex + entriesMap map[tcpip.Address]NeighborEntry +} + +func toAddress(i int) tcpip.Address { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint16(i)) + return tcpip.Address(buf.String()) +} + +func toLinkAddress(i int) tcpip.LinkAddress { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint32(i)) + return tcpip.LinkAddress(buf.String()) +} + +// newTestEntryStore returns a testEntryStore pre-populated with entries. +func newTestEntryStore() *testEntryStore { + store := &testEntryStore{ + entriesMap: make(map[tcpip.Address]NeighborEntry), + } + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + linkAddr := toLinkAddress(i) + + store.entriesMap[addr] = NeighborEntry{ + Addr: addr, + LocalAddr: testEntryLocalAddr, + LinkAddr: linkAddr, + } + } + return store +} + +// size returns the number of entries in the store. +func (s *testEntryStore) size() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.entriesMap) +} + +// entry returns the entry at index i. Returns an empty entry and false if i is +// out of bounds. +func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { + return s.entryByAddr(toAddress(i)) +} + +// entryByAddr returns the entry matching addr for situations when the index is +// not available. Returns an empty entry and false if no entries match addr. +func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + entry, ok := s.entriesMap[addr] + return entry, ok +} + +// entries returns all entries in the store. +func (s *testEntryStore) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(s.entriesMap)) + s.mu.RLock() + defer s.mu.RUnlock() + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + if entry, ok := s.entriesMap[addr]; ok { + entries = append(entries, entry) + } + } + return entries +} + +// set modifies the link addresses of an entry. +func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { + addr := toAddress(i) + s.mu.Lock() + defer s.mu.Unlock() + if entry, ok := s.entriesMap[addr]; ok { + entry.LinkAddr = linkAddr + s.entriesMap[addr] = entry + } +} + +// testNeighborResolver implements LinkAddressResolver to emulate sending a +// neighbor probe. +type testNeighborResolver struct { + clock tcpip.Clock + neigh *neighborCache + entries *testEntryStore + delay time.Duration + onLinkAddressRequest func() +} + +var _ LinkAddressResolver = (*testNeighborResolver)(nil) + +func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(addr) + }) + + // Execute post address resolution action, if available. + if f := r.onLinkAddressRequest; f != nil { + f() + } + return nil +} + +// fakeRequest emulates handling a response for a link address request. +func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { + if entry, ok := r.entries.entryByAddr(addr); ok { + r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } +} + +func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == testEntryBroadcastAddr { + return testEntryBroadcastLinkAddr, true + } + return "", false +} + +func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return 0 +} + +type entryEvent struct { + nicID tcpip.NICID + address tcpip.Address + linkAddr tcpip.LinkAddress + state NeighborState +} + +func TestNeighborCacheGetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheSetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + neigh.setConfig(c) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheEntry(t *testing.T) { + c := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a +// LinkAddressResolver returns ErrNoLinkAddress. +func TestNeighborCacheEntryNoLinkAddress(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + store := newTestEntryStore() + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil) + if err != tcpip.ErrNoLinkAddress { + t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveEntry(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } +} + +type testContext struct { + clock *fakeClock + neigh *neighborCache + store *testEntryStore + linkRes *testNeighborResolver + nudDisp *testNUDDispatcher +} + +func newTestContext(c NUDConfigurations) testContext { + nudDisp := &testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + return testContext{ + clock: clock, + neigh: neigh, + store: store, + linkRes: linkRes, + nudDisp: nudDisp, + } +} + +type overflowOptions struct { + startAtEntryIndex int + wantStaticEntries []NeighborEntry +} + +func (c *testContext) overflowCache(opts overflowOptions) error { + // Fill the neighbor cache to capacity to verify the LRU eviction strategy is + // working properly after the entry removal. + for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + // Add a new entry + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(c.neigh.config().RetransmitTimer) + + var wantEvents []testEntryEventInfo + + // When beyond the full capacity, the cache will evict an entry as per the + // LRU eviction strategy. Note that the number of static entries should not + // affect the total number of dynamic entries that can be added. + if i >= neighborCacheSize+opts.startAtEntryIndex { + removedEntry, ok := c.store.entry(i - neighborCacheSize) + if !ok { + return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + } + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }) + } + + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, testEntryEventInfo{ + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }) + + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the most recent entries. The order of entries reported + // by entries() is undeterministic, so entries have to be sorted before + // comparison. + wantUnsortedEntries := opts.wantStaticEntries + for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + return nil +} + +// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy +// respects the dynamic entry count. +func TestNeighborCacheOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when an entry is removed. +func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(c.neigh.config().RetransmitTimer) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the entry + c.neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that +// adding a duplicate static entry with the same link address does not dispatch +// any events. +func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that +// adding a duplicate static entry with a different link address dispatches a +// change event. +func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a duplicate entry with a different link address + staticLinkAddr += "duplicate" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when a static entry is +// added then removed. In this case, the dynamic entry count shouldn't have +// been touched. +func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.removeEntry(entry.Addr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU +// cache eviction strategy keeps count of the dynamic entry count when an entry +// is overwritten by a static entry. Static entries should not count towards +// the size of the LRU cache. +func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Override the entry with a static one using the same address + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: staticLinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheNotifiesWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + clock.advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + id, ok := s.Fetch(false /* block */) + if !ok { + t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + if id != wakerID { + t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + + // Remove the waker before the neighbor cache has the opportunity to send a + // notification. + neigh.removeWaker(entry.Addr, &w) + clock.advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + if id, ok := s.Fetch(false /* block */); ok { + t.Errorf("unexpected notification from waker with id %d", id) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.neigh.entry(entry.Addr, "", nil, nil) + if err != nil { + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheClear(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add a dynamic entry. + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a static entry. + neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Clear shoud remove both dynamic and static entries. + neigh.clear() + + // Remove events dispatched from clear() have no deterministic order so they + // need to be sorted beforehand. + wantUnsortedEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction +// strategy keeps count of the dynamic entry count when all entries are +// cleared. +func TestNeighborCacheClearThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Clear the cache. + c.neigh.clear() + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + frequentlyUsedEntry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + + // The following logic is very similar to overflowCache, but + // periodically refreshes the frequently used entry. + + // Fill the neighbor cache to capacity + for i := 0; i < neighborCacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Keep adding more entries + for i := neighborCacheSize; i < store.size(); i++ { + // Periodically refresh the frequently used entry + if i%(neighborCacheSize/2) == 0 { + _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + } + } + + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // An entry should have been removed, as per the LRU eviction strategy + removedEntry, ok := store.entry(i - neighborCacheSize + 1) + if !ok { + t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the frequently used entry and the most recent entries. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + wantUnsortedEntries := []NeighborEntry{ + { + Addr: frequentlyUsedEntry.Addr, + LocalAddr: frequentlyUsedEntry.LocalAddr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, + }, + } + + for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheConcurrent(t *testing.T) { + const concurrentProcesses = 16 + + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + storeEntries := store.entries() + for _, entry := range storeEntries { + var wg sync.WaitGroup + for r := 0; r < concurrentProcesses; r++ { + wg.Add(1) + go func(entry NeighborEntry) { + defer wg.Done() + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + } + }(entry) + } + + // Wait for all gorountines to send a request + wg.Wait() + + // Process all the requests for a single entry concurrently + clock.advance(typicalLatency) + } + + // All goroutines add in the same order and add more values than can fit in + // the cache. Our eviction strategy requires that the last entries are + // present, up to the size of the neighbor cache, and the rest are missing. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + var wantUnsortedEntries []NeighborEntry + for i := store.size() - neighborCacheSize; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Errorf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheReplace(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add an entry + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // Verify the entry exists + e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + if doneCh != nil { + t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + } + if t.Failed() { + t.FailNow() + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + } + + // Notify of a link address change + var updatedLinkAddr tcpip.LinkAddress + { + entry, ok := store.entry(1) + if !ok { + t.Fatalf("store.entry(1) not found") + } + updatedLinkAddr = entry.LinkAddr + } + store.set(0, updatedLinkAddr) + neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + + // Requesting the entry again should start address resolution + { + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(config.DelayFirstProbeTime + typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + } + + // Verify the entry's new link address + { + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + clock.advance(typicalLatency) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want = NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: updatedLinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + } +} + +func TestNeighborCacheResolutionFailed(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + + var requestCount uint32 + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + onLinkAddressRequest: func() { + atomic.AddUint32(&requestCount, 1) + }, + } + + // First, sanity check that resolution is working + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + + // Verify that address resolution for an unknown address returns ErrNoLinkAddress + before := atomic.LoadUint32(&requestCount) + + entry.Addr += "2" + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } + + maxAttempts := neigh.config().MaxUnicastProbes + if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } +} + +// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes +// probes and not retrieving a confirmation before the duration defined by +// MaxMulticastProbes * RetransmitTimer. +func TestNeighborCacheResolutionTimeout(t *testing.T) { + config := DefaultNUDConfigurations() + config.RetransmitTimer = time.Millisecond // small enough to cause timeout + + clock := newFakeClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: time.Minute, // large enough to cause timeout + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } +} + +// TestNeighborCacheStaticResolution checks that static link addresses are +// resolved immediately and don't send resolution requests. +func TestNeighborCacheStaticResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + } + want := NeighborEntry{ + Addr: testEntryBroadcastAddr, + LocalAddr: testEntryLocalAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + } +} + +func BenchmarkCacheClear(b *testing.B) { + b.StopTimer() + config := DefaultNUDConfigurations() + clock := &tcpip.StdClock{} + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: 0, + } + + // Clear for every possible size of the cache + for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + // Fill the neighbor cache to capacity. + for i := 0; i < cacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + b.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh != nil { + <-doneCh + } + } + + b.StartTimer() + neigh.clear() + b.StopTimer() + } +} diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go new file mode 100644 index 000000000..0068cacb8 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -0,0 +1,482 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// NeighborEntry describes a neighboring device in the local network. +type NeighborEntry struct { + Addr tcpip.Address + LocalAddr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +// NeighborState defines the state of a NeighborEntry within the Neighbor +// Unreachability Detection state machine, as per RFC 4861 section 7.3.2. +type NeighborState uint8 + +const ( + // Unknown means reachability has not been verified yet. This is the initial + // state of entries that have been created automatically by the Neighbor + // Unreachability Detection state machine. + Unknown NeighborState = iota + // Incomplete means that there is an outstanding request to resolve the + // address. + Incomplete + // Reachable means the path to the neighbor is functioning properly for both + // receive and transmit paths. + Reachable + // Stale means reachability to the neighbor is unknown, but packets are still + // able to be transmitted to the possibly stale link address. + Stale + // Delay means reachability to the neighbor is unknown and pending + // confirmation from an upper-level protocol like TCP, but packets are still + // able to be transmitted to the possibly stale link address. + Delay + // Probe means a reachability confirmation is actively being sought by + // periodically retransmitting reachability probes until a reachability + // confirmation is received, or until the max amount of probes has been sent. + Probe + // Static describes entries that have been explicitly added by the user. They + // do not expire and are not deleted until explicitly removed. + Static + // Failed means traffic should not be sent to this neighbor since attempts of + // reachability have returned inconclusive. + Failed +) + +// neighborEntry implements a neighbor entry's individual node behavior, as per +// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in +// parallel with the sending of packets to a neighbor, necessitating the +// entry's lock to be acquired for all operations. +type neighborEntry struct { + neighborEntryEntry + + nic *NIC + protocol tcpip.NetworkProtocolNumber + + // linkRes provides the functionality to send reachability probes, used in + // Neighbor Unreachability Detection. + linkRes LinkAddressResolver + + // nudState points to the Neighbor Unreachability Detection configuration. + nudState *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + neigh NeighborEntry + + // wakers is a set of waiters for address resolution result. Anytime state + // transitions out of incomplete these waiters are notified. It is nil iff + // address resolution is ongoing and no clients are waiting for the result. + wakers map[*sleep.Waker]struct{} + + // done is used to allow callers to wait on address resolution. It is nil + // iff nudState is not Reachable and address resolution is not yet in + // progress. + done chan struct{} + + isRouter bool + job *tcpip.Job +} + +// newNeighborEntry creates a neighbor cache entry starting at the default +// state, Unknown. Transition out of Unknown by calling either +// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created +// neighborEntry. +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { + return &neighborEntry{ + nic: nic, + linkRes: linkRes, + nudState: nudState, + neigh: NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + State: Unknown, + }, + } +} + +// newStaticNeighborEntry creates a neighbor cache entry starting at the Static +// state. The entry can only transition out of Static by directly calling +// `setStateLocked`. +func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + if nic.stack.nudDisp != nil { + nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + } + return &neighborEntry{ + nic: nic, + nudState: state, + neigh: NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + }, + } +} + +// addWaker adds w to the list of wakers waiting for address resolution. +// Assumes the entry has already been appropriately locked. +func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { + if w == nil { + return + } + if e.wakers == nil { + e.wakers = make(map[*sleep.Waker]struct{}) + } + e.wakers[w] = struct{}{} +} + +// notifyWakersLocked notifies those waiting for address resolution, whether it +// succeeded or failed. Assumes the entry has already been appropriately locked. +func (e *neighborEntry) notifyWakersLocked() { + for w := range e.wakers { + w.Assert() + } + e.wakers = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has +// been added. +func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry +// has changed state or link-layer address. +func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry +// has been removed. +func (e *neighborEntry) dispatchRemoveEventLocked() { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + } +} + +// setStateLocked transitions the entry to the specified state immediately. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +// +// e.mu MUST be locked. +func (e *neighborEntry) setStateLocked(next NeighborState) { + // Cancel the previously scheduled action, if there is one. Entries in + // Unknown, Stale, or Static state do not have scheduled actions. + if timer := e.job; timer != nil { + timer.Cancel() + } + + prev := e.neigh.State + e.neigh.State = next + e.neigh.UpdatedAt = time.Now() + config := e.nudState.Config() + + switch next { + case Incomplete: + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendMulticastProbe() + + case Reachable: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + }) + e.job.Schedule(e.nudState.ReachableTime()) + + case Delay: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Probe) + e.setStateLocked(Probe) + }) + e.job.Schedule(config.DelayFirstProbeTime) + + case Probe: + var retryCounter uint32 + var sendUnicastProbe func() + + sendUnicastProbe = func() { + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendUnicastProbe() + + case Failed: + e.notifyWakersLocked() + e.job = e.nic.stack.newJob(&e.mu, func() { + e.nic.neigh.removeEntryLocked(e) + }) + e.job.Schedule(config.UnreachableTime) + + case Unknown, Stale, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next)) + } +} + +// handlePacketQueuedLocked advances the state machine according to a packet +// being queued for outgoing transmission. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +func (e *neighborEntry) handlePacketQueuedLocked() { + switch e.neigh.State { + case Unknown: + e.dispatchAddEventLocked(Incomplete) + e.setStateLocked(Incomplete) + + case Stale: + e.dispatchChangeEventLocked(Delay) + e.setStateLocked(Delay) + + case Incomplete, Reachable, Delay, Probe, Static, Failed: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or +// Neighbor Solicitation for ARP or NDP, respectively). +// +// Follows the logic defined in RFC 4861 section 7.2.3. +func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { + // Probes MUST be silently discarded if the target address is tentative, does + // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These + // checks MUST be done by the NetworkEndpoint. + + switch e.neigh.State { + case Unknown, Incomplete, Failed: + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchAddEventLocked(Stale) + e.setStateLocked(Stale) + e.notifyWakersLocked() + + case Reachable, Delay, Probe: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + + case Stale: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + } + + case Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleConfirmationLocked processes an incoming neighbor confirmation +// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively). +// +// Follows the state machine defined by RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce Cryptographically +// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the +// claimed source of an NDP message is the owner of the claimed address. +func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + switch e.neigh.State { + case Incomplete: + if len(linkAddr) == 0 { + // "If the link layer has addresses and no Target Link-Layer Address + // option is included, the receiving node SHOULD silently discard the + // received advertisement." - RFC 4861 section 7.2.5 + break + } + + e.neigh.LinkAddr = linkAddr + if flags.Solicited { + e.dispatchChangeEventLocked(Reachable) + e.setStateLocked(Reachable) + } else { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + e.isRouter = flags.IsRouter + e.notifyWakersLocked() + + // "Note that the Override flag is ignored if the entry is in the + // INCOMPLETE state." - RFC 4861 section 7.2.5 + + case Reachable, Stale, Delay, Probe: + sameLinkAddr := e.neigh.LinkAddr == linkAddr + + if !sameLinkAddr { + if !flags.Override { + if e.neigh.State == Reachable { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + break + } + + e.neigh.LinkAddr = linkAddr + + if !flags.Solicited { + if e.neigh.State != Stale { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } else { + // Notify the LinkAddr change, even though NUD state hasn't changed. + e.dispatchChangeEventLocked(e.neigh.State) + } + break + } + } + + if flags.Solicited && (flags.Override || sameLinkAddr) { + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + } + // Set state to Reachable again to refresh timers. + e.setStateLocked(Reachable) + e.notifyWakersLocked() + } + + if e.isRouter && !flags.IsRouter { + // "In those cases where the IsRouter flag changes from TRUE to FALSE as + // a result of this update, the node MUST remove that router from the + // Default Router List and update the Destination Cache entries for all + // destinations using that neighbor as a router as specified in Section + // 7.3.3. This is needed to detect when a node that is used as a router + // stops forwarding packets due to being configured as a host." + // - RFC 4861 section 7.2.5 + e.nic.mu.Lock() + e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr) + e.nic.mu.Unlock() + } + e.isRouter = flags.IsRouter + + case Unknown, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol +// (e.g. TCP acknowledgements) reachability confirmation. +func (e *neighborEntry) handleUpperLevelConfirmationLocked() { + switch e.neigh.State { + case Reachable, Stale, Delay, Probe: + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + // Set state to Reachable again to refresh timers. + } + e.setStateLocked(Reachable) + + case Unknown, Incomplete, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go new file mode 100644 index 000000000..08c9ccd25 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -0,0 +1,2770 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "math" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + + entryTestNICID tcpip.NICID = 1 + entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") + entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") + + // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match the + // MTU of loopback interfaces on Linux systems. + entryTestNetDefaultMTU = 65536 +) + +// eventDiffOpts are the options passed to cmp.Diff to compare entry events. +// The UpdatedAt field is ignored due to a lack of a deterministic method to +// predict the time that an event will be dispatched. +func eventDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + } +} + +// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to +// sort slices of events for cases where ordering must be ignored. +func eventDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +// The following unit tests exercise every state transition and verify its +// behavior with RFC 4681. +// +// | From | To | Cause | Action | Event | +// | ========== | ========== | ========================================== | =============== | ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | +// | Unknown | Stale | Probe w/ unknown address | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | +// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | +// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | +// | Reachable | Stale | Reachable timer expired | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | +// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | +// | Stale | Delay | Packet sent | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | Changed | +// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | Changed | +// | Delay | Probe | Delay timer expired | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | Changed | +// | Probe | Probe | Retransmit timer expired | Send probe | Changed | +// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Failed | | Unreachability timer expired | Delete entry | | + +type testEntryEventType uint8 + +const ( + entryTestAdded testEntryEventType = iota + entryTestChanged + entryTestRemoved +) + +func (t testEntryEventType) String() string { + switch t { + case entryTestAdded: + return "add" + case entryTestChanged: + return "change" + case entryTestRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +// Fields are exported for use with cmp.Diff. +type testEntryEventInfo struct { + EventType testEntryEventType + NICID tcpip.NICID + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +func (e testEntryEventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) +} + +// testNUDDispatcher implements NUDDispatcher to validate the dispatching of +// events upon certain NUD state machine events. +type testNUDDispatcher struct { + mu sync.Mutex + events []testEntryEventInfo +} + +var _ NUDDispatcher = (*testNUDDispatcher)(nil) + +func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { + d.mu.Lock() + defer d.mu.Unlock() + d.events = append(d.events, e) +} + +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestAdded, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestChanged, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +type entryTestLinkResolver struct { + mu sync.Mutex + probes []entryTestProbeInfo +} + +var _ LinkAddressResolver = (*entryTestLinkResolver)(nil) + +type entryTestProbeInfo struct { + RemoteAddress tcpip.Address + RemoteLinkAddress tcpip.LinkAddress + LocalAddress tcpip.Address +} + +func (p entryTestProbeInfo) String() string { + return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress) +} + +// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts +// to the local network if linkAddr is the zero value. +func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + p := entryTestProbeInfo{ + RemoteAddress: addr, + RemoteLinkAddress: linkAddr, + LocalAddress: localAddr, + } + r.mu.Lock() + defer r.mu.Unlock() + r.probes = append(r.probes, p) + return nil +} + +// ResolveStaticAddress attempts to resolve address without sending requests. +// It either resolves the name immediately or returns the empty LinkAddress. +func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + return "", false +} + +// LinkAddressProtocol returns the network protocol of the addresses this +// resolver can resolve. +func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return entryTestNetNumber +} + +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) { + clock := newFakeClock() + disp := testNUDDispatcher{} + nic := NIC{ + id: entryTestNICID, + linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + stack: &Stack{ + clock: clock, + nudDisp: &disp, + }, + } + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + nudState := NewNUDState(c, rng) + linkRes := entryTestLinkResolver{} + entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes) + + // Stub out ndpState to verify modification of default routers. + nic.mu.ndp = ndpState{ + nic: &nic, + defaultRouters: make(map[tcpip.Address]defaultRouterState), + } + + // Stub out the neighbor cache to verify deletion from the cache. + nic.neigh = &neighborCache{ + nic: &nic, + state: nudState, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + nic.neigh.cache[entryTestAddr1] = entry + + return entry, &disp, &linkRes, clock +} + +// TestEntryInitiallyUnknown verifies that the state of a newly created +// neighborEntry is Unknown. +func TestEntryInitiallyUnknown(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(time.Hour) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToIncomplete(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + { + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +func TestEntryUnknownToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + updatedAt := e.neigh.UpdatedAt + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // UpdatedAt should remain the same during address resolution. + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // UpdatedAt should change after failing address resolution. Timing out after + // sending the last probe transitions the entry to Failed. + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + clock.advance(c.RetransmitTimer) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) + } + e.mu.Unlock() +} + +func TestEntryIncompleteToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +// TestEntryAddsAndClearsWakers verifies that wakers are added when +// addWakerLocked is called and cleared when address resolution finishes. In +// this case, address resolution will finish when transitioning from Incomplete +// to Reachable. +func TestEntryAddsAndClearsWakers(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + w := sleep.Waker{} + s := sleep.Sleeper{} + s.AddWaker(&w, 123) + defer s.Done() + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got := e.wakers; got != nil { + t.Errorf("got e.wakers = %v, want = nil", got) + } + e.addWakerLocked(&w) + if got, want := w.IsAsserted(), false; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + if e.wakers == nil { + t.Error("expected e.wakers to be non-nil") + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.wakers != nil { + t.Errorf("got e.wakers = %v, want = nil", e.wakers) + } + if got, want := w.IsAsserted(), true; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + linkRes.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +type testLocker struct{} + +var _ sync.Locker = (*testLocker)(nil) + +func (*testLocker) Lock() {} +func (*testLocker) Unlock() {} + +func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{ + invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}), + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.isRouter, false; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok { + t.Errorf("unexpected defaultRouter for %s", entryTestAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryReachableToStaleWhenTimeout(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToDelay(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleUpperLevelConfirmationLocked() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 1 + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToProbe(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryFailedGetsDeleted(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // Verify the cache contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { + t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + } + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + // Verify the cache no longer contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { + t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) + } +} diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go new file mode 100644 index 000000000..aa7311ec6 --- /dev/null +++ b/pkg/tcpip/stack/neighborstate_string.go @@ -0,0 +1,44 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type NeighborState"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Unknown-0] + _ = x[Incomplete-1] + _ = x[Reachable-2] + _ = x[Stale-3] + _ = x[Delay-4] + _ = x[Probe-5] + _ = x[Static-6] + _ = x[Failed-7] +} + +const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed" + +var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53} + +func (i NeighborState) String() string { + if i >= NeighborState(len(_NeighborState_index)-1) { + return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]] +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 54103fdb3..f21066fce 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "math/rand" "reflect" "sort" "strings" @@ -45,6 +46,7 @@ type NIC struct { context NICContext stats NICStats + neigh *neighborCache mu struct { sync.RWMutex @@ -141,6 +143,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } + // Check for Neighbor Unreachability Detection support. + if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 { + rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) + nic.neigh = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } + nic.linkEP.Attach(nic) return nic @@ -181,7 +193,7 @@ func (n *NIC) disableLocked() *tcpip.Error { return nil } - // TODO(b/147015577): Should Routes that are currently bound to n be + // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be // invalidated? Currently, Routes will continue to work when a NIC is enabled // again, and applications may not know that the underlying NIC was ever // disabled. @@ -457,8 +469,20 @@ type ipv6AddrCandidate struct { // remoteAddr must be a valid IPv6 address. func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { n.mu.RLock() - defer n.mu.RUnlock() + ref := n.primaryIPv6EndpointRLocked(remoteAddr) + n.mu.RUnlock() + return ref +} +// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 and 7 are followed. +// +// remoteAddr must be a valid IPv6 address. +// +// n.mu MUST be read locked. +func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] if len(primaryAddrs) == 0 { @@ -568,11 +592,6 @@ const ( // promiscuous indicates that the NIC's promiscuous flag should be observed // when getting a NIC's referenced network endpoint. promiscuous - - // forceSpoofing indicates that the NIC should be assumed to be spoofing, - // regardless of what the NIC's spoofing flag is when getting a NIC's - // referenced network endpoint. - forceSpoofing ) func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { @@ -591,8 +610,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { - id := NetworkEndpointID{address} - n.mu.RLock() var spoofingOrPromiscuous bool @@ -601,24 +618,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t spoofingOrPromiscuous = n.mu.spoofing case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous - case forceSpoofing: - spoofingOrPromiscuous = true } - if ref, ok := n.mu.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // An endpoint with this id exists, check if it can be used and return it. - switch ref.getKind() { - case permanentExpired: - if !spoofingOrPromiscuous { - n.mu.RUnlock() - return nil - } - fallthrough - case temporary, permanent: - if ref.tryIncRef() { - n.mu.RUnlock() - return ref - } + if !ref.isAssignedRLocked(spoofingOrPromiscuous) { + n.mu.RUnlock() + return nil + } + + if ref.tryIncRef() { + n.mu.RUnlock() + return ref } } @@ -654,11 +665,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.mu.endpoints[id]; ok { + ref := n.getRefOrCreateTempLocked(protocol, address, peb) + n.mu.Unlock() + return ref +} + +/// getRefOrCreateTempLocked returns an existing endpoint for address or creates +/// and returns a temporary endpoint. +func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { - n.mu.Unlock() return ref } // tryIncRef failing means the endpoint is scheduled to be removed once the @@ -670,7 +688,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { - n.mu.Unlock() return nil } ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ @@ -680,8 +697,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t PrefixLen: netProto.DefaultPrefixLen(), }, }, peb, temporary, static, false) - - n.mu.Unlock() return ref } @@ -1153,7 +1168,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1167,7 +1182,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1197,27 +1212,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // Are any packet sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Check whether there are packet sockets listening for every protocol. - // If we received a packet with protocol EthernetProtocolAll, then the - // previous for loop will have handled it. - if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) - } + // Add any other packet sockets that maybe listening for all protocols. + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + p := pkt.Clone() + p.PktType = tcpip.PacketHost + ep.HandlePacket(n.id, local, protocol, p) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + // Parse headers. + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { + // The packet is too small to contain a network header. n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + if hasTransportHdr { + // Parse the transport header if present. + if state, ok := n.stack.transportProtocols[transProtoNum]; ok { + state.proto.Parse(pkt) + } + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1229,18 +1251,19 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. - if protocol == header.IPv4ProtocolNumber { + // Loopback traffic skips the prerouting chain. + if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { + if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } } if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) + handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) return } @@ -1298,24 +1321,55 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. +func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + // We do not deliver to protocol specific packet endpoints as on Linux + // only ETH_P_ALL endpoints get outbound packets. + // Add any other packet sockets that maybe listening for all protocols. + packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + n.mu.RUnlock() + for _, ep := range packetEPs { + p := pkt.Clone() + p.PktType = tcpip.PacketOutgoing + // Add the link layer header as outgoing packets are intercepted + // before the link layer header is created. + n.linkEP.AddHeader(local, remote, protocol, p) + ep.HandlePacket(n.id, local, protocol, p) + } +} + +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this + // by having lower layers explicity write each header instead of just + // pkt.Header. + + // pkt may have set its NetworkHeader and TransportHeader. If we're + // forwarding, we'll have to copy them into pkt.Header. + pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) + if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) + } + if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) } + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1329,13 +1383,34 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // from fragments. + if pkt.TransportHeader == nil { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + // ICMP packets may be longer, but until icmp.Parse is implemented, here + // we parse it using the minimum size. + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + pkt.TransportHeader = transHeader + pkt.Data.TrimFront(len(pkt.TransportHeader)) + } else { + // This is either a bad packet or was re-assembled from fragments. + transProto.Parse(pkt) + } + } + + if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,7 +1437,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -1477,6 +1552,27 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) { n.mu.Unlock() } +// NUDConfigs gets the NUD configurations for n. +func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) { + if n.neigh == nil { + return NUDConfigurations{}, tcpip.ErrNotSupported + } + return n.neigh.config(), nil +} + +// setNUDConfigs sets the NUD configurations for n. +// +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported + } + c.resetInvalidFields() + n.neigh.setConfig(c) + return nil +} + // handleNDPRA handles an NDP Router Advertisement message that arrived on n. func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { n.mu.Lock() @@ -1611,8 +1707,8 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address -// has been removed), or the NIC to be in spoofing mode. +// packet. It requires the endpoint to not be marked expired (i.e., its address) +// has been removed) unless the NIC is in spoofing mode, or temporary. func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { r.nic.mu.RLock() defer r.nic.mu.RUnlock() @@ -1620,13 +1716,28 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { return r.isValidForOutgoingRLocked() } -// isValidForOutgoingRLocked returns true if the endpoint can be used to send -// out a packet. It requires the endpoint to not be marked expired (i.e., its -// address has been removed), or the NIC to be in spoofing mode. -// -// r's NIC must be read locked. +// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires +// r.nic.mu to be read locked. func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) + if !r.nic.mu.enabled { + return false + } + + return r.isAssignedRLocked(r.nic.mu.spoofing) +} + +// isAssignedRLocked returns true if r is considered to be assigned to the NIC. +// +// r.nic.mu must be read locked. +func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool { + switch r.getKind() { + case permanentTentative: + return false + case permanentExpired: + return spoofingOrPromiscuous + default: + return true + } } // expireLocked decrements the reference count and marks the permanent endpoint diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index d672fc157..a70792b50 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,278 @@ package stack import ( + "math" "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) +var _ LinkEndpoint = (*testLinkEndpoint)(nil) + +// A LinkEndpoint that throws away outgoing packets. +// +// We use this instead of the channel endpoint as the channel package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testLinkEndpoint struct { + dispatcher NetworkDispatcher +} + +// Attach implements LinkEndpoint.Attach. +func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements LinkEndpoint.IsAttached. +func (e *testLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements LinkEndpoint.MTU. +func (*testLinkEndpoint) MTU() uint32 { + return math.MaxUint16 +} + +// Capabilities implements LinkEndpoint.Capabilities. +func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { + return CapabilityResolutionRequired +} + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*testLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// Wait implements LinkEndpoint.Wait. +func (*testLinkEndpoint) Wait() {} + +// WritePacket implements LinkEndpoint.WritePacket. +func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) + +// An IPv6 NetworkEndpoint that throws away outgoing packets. +// +// We use this instead of ipv6.endpoint because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Endpoint struct { + nicID tcpip.NICID + id NetworkEndpointID + prefixLen int + linkEP LinkEndpoint + protocol *testIPv6Protocol +} + +// DefaultTTL implements NetworkEndpoint.DefaultTTL. +func (*testIPv6Endpoint) DefaultTTL() uint8 { + return 0 +} + +// MTU implements NetworkEndpoint.MTU. +func (e *testIPv6Endpoint) MTU() uint32 { + return e.linkEP.MTU() - header.IPv6MinimumSize +} + +// Capabilities implements NetworkEndpoint.Capabilities. +func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. +func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket implements NetworkEndpoint.WritePacket. +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements NetworkEndpoint.WritePackets. +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteHeaderIncludedPacket implements +// NetworkEndpoint.WriteHeaderIncludedPacket. +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +// ID implements NetworkEndpoint.ID. +func (e *testIPv6Endpoint) ID() *NetworkEndpointID { + return &e.id +} + +// PrefixLen implements NetworkEndpoint.PrefixLen. +func (e *testIPv6Endpoint) PrefixLen() int { + return e.prefixLen +} + +// NICID implements NetworkEndpoint.NICID. +func (e *testIPv6Endpoint) NICID() tcpip.NICID { + return e.nicID +} + +// HandlePacket implements NetworkEndpoint.HandlePacket. +func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { +} + +// Close implements NetworkEndpoint.Close. +func (*testIPv6Endpoint) Close() {} + +// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. +func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +var _ NetworkProtocol = (*testIPv6Protocol)(nil) + +// An IPv6 NetworkProtocol that supports the bare minimum to make a stack +// believe it supports IPv6. +// +// We use this instead of ipv6.protocol because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Protocol struct{} + +// Number implements NetworkProtocol.Number. +func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. +func (*testIPv6Protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. +func (*testIPv6Protocol) DefaultPrefixLen() int { + return header.IPv6AddressSize * 8 +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint implements NetworkProtocol.NewEndpoint. +func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { + return &testIPv6Endpoint{ + nicID: nicID, + id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + linkEP: linkEP, + protocol: p, + }, nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { + return nil +} + +// Option implements NetworkProtocol.Option. +func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { + return nil +} + +// Close implements NetworkProtocol.Close. +func (*testIPv6Protocol) Close() {} + +// Wait implements NetworkProtocol.Wait. +func (*testIPv6Protocol) Wait() {} + +// Parse implements NetworkProtocol.Parse. +func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + return 0, false, false +} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) + +// LinkAddressProtocol implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// LinkAddressRequest implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { + return nil +} + +// ResolveStaticAddress implements LinkAddressResolver. +func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if header.IsV6MulticastAddress(addr) { + return header.EthernetAddressFromMulticastIPv6Address(addr), true + } + return "", false +} + +// Test the race condition where a NIC is removed and an RS timer fires at the +// same time. +func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { + const ( + nicID = 1 + + maxRtrSolicitations = 5 + ) + + e := testLinkEndpoint{} + s := New(Options{ + NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, + NDPConfigs: NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: minimumRtrSolicitationInterval, + }, + }) + + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) + } + + s.mu.Lock() + // Wait for the router solicitation timer to fire and block trying to obtain + // the stack lock when doing link address resolution. + time.Sleep(minimumRtrSolicitationInterval * 2) + if err := s.removeNICLocked(nicID); err != nil { + t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) + } + s.mu.Unlock() +} + func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. @@ -44,7 +311,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go new file mode 100644 index 000000000..f848d50ad --- /dev/null +++ b/pkg/tcpip/stack/nud.go @@ -0,0 +1,466 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "math" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // defaultBaseReachableTime is the default base duration for computing the + // random reachable time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of a uniformly distributed random value between the minimum and + // maximum random factors, multiplied by the base reachable time. Using a + // random component eliminates the possibility that Neighbor Unreachability + // Detection messages will synchronize with each other. + // + // Default taken from REACHABLE_TIME of RFC 4861 section 10. + defaultBaseReachableTime = 30 * time.Second + + // minimumBaseReachableTime is the minimum base duration for computing the + // random reachable time. + // + // Minimum = 1ms + minimumBaseReachableTime = time.Millisecond + + // defaultMinRandomFactor is the default minimum value of the random factor + // used for computing reachable time. + // + // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10. + defaultMinRandomFactor = 0.5 + + // defaultMaxRandomFactor is the default maximum value of the random factor + // used for computing reachable time. + // + // The default value depends on the value of MinRandomFactor. + // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10, + // the value from the RFC will be used; otherwise, the default is + // MinRandomFactor multiplied by three. + defaultMaxRandomFactor = 1.5 + + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + + // defaultDelayFirstProbeTime is the default duration to wait for a + // non-Neighbor-Discovery related protocol to reconfirm reachability after + // entering the DELAY state. After this time, a reachability probe will be + // sent and the entry will transition to the PROBE state. + // + // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10. + defaultDelayFirstProbeTime = 5 * time.Second + + // defaultMaxMulticastProbes is the default number of reachabililty probes + // to send before concluding negative reachability and deleting the neighbor + // entry from the INCOMPLETE state. + // + // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10. + defaultMaxMulticastProbes = 3 + + // defaultMaxUnicastProbes is the default number of reachability probes to + // send before concluding retransmission from within the PROBE state should + // cease and the entry SHOULD be deleted. + // + // Default taken from MAX_UNICASE_SOLICIT of RFC 4861 section 10. + defaultMaxUnicastProbes = 3 + + // defaultMaxAnycastDelayTime is the default time in which the stack SHOULD + // delay sending a response for a random time between 0 and this time, if the + // target address is an anycast address. + // + // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10. + defaultMaxAnycastDelayTime = time.Second + + // defaultMaxReachbilityConfirmations is the default amount of unsolicited + // reachability confirmation messages a node MAY send to all-node multicast + // address when it determines its link-layer address has changed. + // + // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. + defaultMaxReachbilityConfirmations = 3 + + // defaultUnreachableTime is the default duration for how long an entry will + // remain in the FAILED state before being removed from the neighbor cache. + // + // Note, there is no equivalent protocol constant defined in RFC 4861. It + // leaves the specifics of any garbage collection mechanism up to the + // implementation. + defaultUnreachableTime = 5 * time.Second +) + +// NUDDispatcher is the interface integrators of netstack must implement to +// receive and handle NUD related events. +type NUDDispatcher interface { + // OnNeighborAdded will be called when a new entry is added to a NIC's (with + // ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) + // neighbor table changes state and/or link address. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborRemoved will be called when an entry is removed from a NIC's + // (with ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) +} + +// ReachabilityConfirmationFlags describes the flags used within a reachability +// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, +// respectively). +type ReachabilityConfirmationFlags struct { + // Solicited indicates that the advertisement was sent in response to a + // reachability probe. + Solicited bool + + // Override indicates that the reachability confirmation should override an + // existing neighbor cache entry and update the cached link-layer address. + // When Override is not set the confirmation will not update a cached + // link-layer address, but will update an existing neighbor cache entry for + // which no link-layer address is known. + Override bool + + // IsRouter indicates that the sender is a router. + IsRouter bool +} + +// NUDHandler communicates external events to the Neighbor Unreachability +// Detection state machine, which is implemented per-interface. This is used by +// network endpoints to inform the Neighbor Cache of probes and confirmations. +type NUDHandler interface { + // HandleProbe processes an incoming neighbor probe (e.g. ARP request or + // Neighbor Solicitation for ARP or NDP, respectively). Validation of the + // probe needs to be performed before calling this function since the + // Neighbor Cache doesn't have access to view the NIC's assigned addresses. + HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) + + // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP + // reply or Neighbor Advertisement for ARP or NDP, respectively). + HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) + + // HandleUpperLevelConfirmation processes an incoming upper-level protocol + // (e.g. TCP acknowledgements) reachability confirmation. + HandleUpperLevelConfirmation(addr tcpip.Address) +} + +// NUDConfigurations is the NUD configurations for the netstack. This is used +// by the neighbor cache to operate the NUD state machine on each device in the +// local network. +type NUDConfigurations struct { + // BaseReachableTime is the base duration for computing the random reachable + // time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of uniformly distributed random value between minRandomFactor and + // maxRandomFactor multiplied by baseReachableTime. Using a random component + // eliminates the possibility that Neighbor Unreachability Detection messages + // will synchronize with each other. + // + // After this time, a neighbor entry will transition from REACHABLE to STALE + // state. + // + // Must be greater than 0. + BaseReachableTime time.Duration + + // LearnBaseReachableTime enables learning BaseReachableTime during runtime + // from the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option. + LearnBaseReachableTime bool + + // MinRandomFactor is the minimum value of the random factor used for + // computing reachable time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be greater than 0. + MinRandomFactor float32 + + // MaxRandomFactor is the maximum value of the random factor used for + // computing reachabile time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be great than or equal to MinRandomFactor. + MaxRandomFactor float32 + + // RetransmitTimer is the duration between retransmission of reachability + // probes in the PROBE state. + RetransmitTimer time.Duration + + // LearnRetransmitTimer enables learning RetransmitTimer during runtime from + // the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option. + LearnRetransmitTimer bool + + // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery + // related protocol to reconfirm reachability after entering the DELAY state. + // After this time, a reachability probe will be sent and the entry will + // transition to the PROBE state. + // + // Must be greater than 0. + DelayFirstProbeTime time.Duration + + // MaxMulticastProbes is the number of reachability probes to send before + // concluding negative reachability and deleting the neighbor entry from the + // INCOMPLETE state. + // + // Must be greater than 0. + MaxMulticastProbes uint32 + + // MaxUnicastProbes is the number of reachability probes to send before + // concluding retransmission from within the PROBE state should cease and + // entry SHOULD be deleted. + // + // Must be greater than 0. + MaxUnicastProbes uint32 + + // MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a + // response for a random time between 0 and this time, if the target address + // is an anycast address. + // + // TODO(gvisor.dev/issue/2242): Use this option when sending solicited + // neighbor confirmations to anycast addresses and proxying neighbor + // confirmations. + MaxAnycastDelayTime time.Duration + + // MaxReachabilityConfirmations is the number of unsolicited reachability + // confirmation messages a node MAY send to all-node multicast address when + // it determines its link-layer address has changed. + // + // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD + // configuration option is necessary. + MaxReachabilityConfirmations uint32 + + // UnreachableTime describes how long an entry will remain in the FAILED + // state before being removed from the neighbor cache. + UnreachableTime time.Duration +} + +// DefaultNUDConfigurations returns a NUDConfigurations populated with default +// values defined by RFC 4861 section 10. +func DefaultNUDConfigurations() NUDConfigurations { + return NUDConfigurations{ + BaseReachableTime: defaultBaseReachableTime, + LearnBaseReachableTime: true, + MinRandomFactor: defaultMinRandomFactor, + MaxRandomFactor: defaultMaxRandomFactor, + RetransmitTimer: defaultRetransmitTimer, + LearnRetransmitTimer: true, + DelayFirstProbeTime: defaultDelayFirstProbeTime, + MaxMulticastProbes: defaultMaxMulticastProbes, + MaxUnicastProbes: defaultMaxUnicastProbes, + MaxAnycastDelayTime: defaultMaxAnycastDelayTime, + MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, + UnreachableTime: defaultUnreachableTime, + } +} + +// resetInvalidFields modifies an invalid NDPConfigurations with valid values. +// If invalid values are present in c, the corresponding default values will be +// used instead. This is needed to check, and conditionally fix, user-specified +// NUDConfigurations. +func (c *NUDConfigurations) resetInvalidFields() { + if c.BaseReachableTime < minimumBaseReachableTime { + c.BaseReachableTime = defaultBaseReachableTime + } + if c.MinRandomFactor <= 0 { + c.MinRandomFactor = defaultMinRandomFactor + } + if c.MaxRandomFactor < c.MinRandomFactor { + c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor) + } + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } + if c.DelayFirstProbeTime == 0 { + c.DelayFirstProbeTime = defaultDelayFirstProbeTime + } + if c.MaxMulticastProbes == 0 { + c.MaxMulticastProbes = defaultMaxMulticastProbes + } + if c.MaxUnicastProbes == 0 { + c.MaxUnicastProbes = defaultMaxUnicastProbes + } + if c.UnreachableTime == 0 { + c.UnreachableTime = defaultUnreachableTime + } +} + +// calcMaxRandomFactor calculates the maximum value of the random factor used +// for computing reachable time. This function is necessary for when the +// default specified in RFC 4861 section 10 is less than the current +// MinRandomFactor. +// +// Assumes minRandomFactor is positive since validation of the minimum value +// should come before the validation of the maximum. +func calcMaxRandomFactor(minRandomFactor float32) float32 { + if minRandomFactor > defaultMaxRandomFactor { + return minRandomFactor * 3 + } + return defaultMaxRandomFactor +} + +// A Rand is a source of random numbers. +type Rand interface { + // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). + Float32() float32 +} + +// NUDState stores states needed for calculating reachable time. +type NUDState struct { + rng Rand + + // mu protects the fields below. + // + // It is necessary for NUDState to handle its own locking since neighbor + // entries may access the NUD state from within the goroutine spawned by + // time.AfterFunc(). This goroutine may run concurrently with the main + // process for controlling the neighbor cache and would otherwise introduce + // race conditions if NUDState was not locked properly. + mu sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 +} + +// NewNUDState returns new NUDState using c as configuration and the specified +// random number generator for use in recomputing ReachableTime. +func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { + s := &NUDState{ + rng: rng, + } + s.config = c + return s +} + +// Config returns the NUD configuration. +func (s *NUDState) Config() NUDConfigurations { + s.mu.RLock() + defer s.mu.RUnlock() + return s.config +} + +// SetConfig replaces the existing NUD configurations with c. +func (s *NUDState) SetConfig(c NUDConfigurations) { + s.mu.Lock() + defer s.mu.Unlock() + s.config = c +} + +// ReachableTime returns the duration to wait for a REACHABLE entry to +// transition into STALE after inactivity. This value is recalculated for new +// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the +// algorithm defined in RFC 4861 section 6.3.2. +func (s *NUDState) ReachableTime() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + + if time.Now().After(s.expiration) || + s.config.BaseReachableTime != s.prevBaseReachableTime || + s.config.MinRandomFactor != s.prevMinRandomFactor || + s.config.MaxRandomFactor != s.prevMaxRandomFactor { + return s.recomputeReachableTimeLocked() + } + return s.reachableTime +} + +// recomputeReachableTimeLocked forces a recalculation of ReachableTime using +// the algorithm defined in RFC 4861 section 6.3.2. +// +// This SHOULD automatically be invoked during certain situations, as per +// RFC 4861 section 6.3.4: +// +// If the received Reachable Time value is non-zero, the host SHOULD set its +// BaseReachableTime variable to the received value. If the new value +// differs from the previous value, the host SHOULD re-compute a new random +// ReachableTime value. ReachableTime is computed as a uniformly +// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR +// times the BaseReachableTime. Using a random component eliminates the +// possibility that Neighbor Unreachability Detection messages will +// synchronize with each other. +// +// In most cases, the advertised Reachable Time value will be the same in +// consecutive Router Advertisements, and a host's BaseReachableTime rarely +// changes. In such cases, an implementation SHOULD ensure that a new +// random value gets re-computed at least once every few hours. +// +// s.mu MUST be locked for writing. +func (s *NUDState) recomputeReachableTimeLocked() time.Duration { + s.prevBaseReachableTime = s.config.BaseReachableTime + s.prevMinRandomFactor = s.config.MinRandomFactor + s.prevMaxRandomFactor = s.config.MaxRandomFactor + + randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + + // Check for overflow, given that minRandomFactor and maxRandomFactor are + // guaranteed to be positive numbers. + if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { + s.reachableTime = time.Duration(math.MaxInt64) + } else if randomFactor == 1 { + // Avoid loss of precision when a large base reachable time is used. + s.reachableTime = s.config.BaseReachableTime + } else { + reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) + s.reachableTime = time.Duration(reachableTime) + } + + s.expiration = time.Now().Add(2 * time.Hour) + return s.reachableTime +} diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go new file mode 100644 index 000000000..2494ee610 --- /dev/null +++ b/pkg/tcpip/stack/nud_test.go @@ -0,0 +1,795 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 + defaultMaxAnycastDelayTime = time.Second + defaultMaxReachbilityConfirmations = 3 + defaultUnreachableTime = 5 * time.Second + + defaultFakeRandomNum = 0.5 +) + +// fakeRand is a deterministic random number generator. +type fakeRand struct { + num float32 +} + +var _ stack.Rand = (*fakeRand)(nil) + +func (f *fakeRand) Float32() float32 { + return f.num +} + +// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NUD configurations using an invalid NICID. +func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + + // No NIC with ID 1 yet. + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to retrieve NUD configurations when the +// stack doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to set NUD configurations when the stack +// doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + } +} + +// TestDefaultNUDConfigurationIsValid verifies that calling +// resetInvalidFields() on the result of DefaultNUDConfigurations() does not +// change anything. DefaultNUDConfigurations() should return a valid +// NUDConfigurations. +func TestDefaultNUDConfigurations(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + c, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got, want := c, stack.DefaultNUDConfigurations(); got != want { + t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + } +} + +func TestNUDConfigurationsBaseReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + baseReachableTime: 0, + want: defaultBaseReachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + baseReachableTime: time.Millisecond, + want: time.Millisecond, + }, + { + name: "MoreThanDefaultBaseReachableTime", + baseReachableTime: 2 * defaultBaseReachableTime, + want: 2 * defaultBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = test.baseReachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.BaseReachableTime; got != test.want { + t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMinRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: -1, + want: defaultMinRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: 0, + want: defaultMinRandomFactor, + }, + // Valid cases + { + name: "MoreThanZero", + minRandomFactor: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MinRandomFactor; got != test.want { + t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + maxRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: -1, + want: defaultMaxRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: 0, + want: defaultMaxRandomFactor, + }, + { + name: "LessThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 0.99, + want: defaultMaxRandomFactor, + }, + { + name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault", + minRandomFactor: defaultMaxRandomFactor * 2, + maxRandomFactor: defaultMaxRandomFactor, + want: defaultMaxRandomFactor * 6, + }, + // Valid cases + { + name: "EqualToMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor, + want: defaultMinRandomFactor, + }, + { + name: "MoreThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 1.1, + want: defaultMinRandomFactor * 1.1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxRandomFactor; got != test.want { + t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsRetransmitTimer(t *testing.T) { + tests := []struct { + name string + retransmitTimer time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + retransmitTimer: 0, + want: defaultRetransmitTimer, + }, + { + name: "LessThanMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer - time.Nanosecond, + want: defaultRetransmitTimer, + }, + // Valid cases + { + name: "EqualToMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer, + want: minimumBaseReachableTime, + }, + { + name: "LargetThanMinimumRetransmitTimer", + retransmitTimer: 2 * minimumBaseReachableTime, + want: 2 * minimumBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.RetransmitTimer = test.retransmitTimer + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.RetransmitTimer; got != test.want { + t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { + tests := []struct { + name string + delayFirstProbeTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + delayFirstProbeTime: 0, + want: defaultDelayFirstProbeTime, + }, + // Valid cases + { + name: "MoreThanZero", + delayFirstProbeTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.DelayFirstProbeTime = test.delayFirstProbeTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.DelayFirstProbeTime; got != test.want { + t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { + tests := []struct { + name string + maxMulticastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxMulticastProbes: 0, + want: defaultMaxMulticastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxMulticastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxMulticastProbes = test.maxMulticastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxMulticastProbes; got != test.want { + t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { + tests := []struct { + name string + maxUnicastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxUnicastProbes: 0, + want: defaultMaxUnicastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxUnicastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxUnicastProbes = test.maxUnicastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxUnicastProbes; got != test.want { + t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsUnreachableTime(t *testing.T) { + tests := []struct { + name string + unreachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + unreachableTime: 0, + want: defaultUnreachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + unreachableTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.UnreachableTime = test.unreachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.UnreachableTime; got != test.want { + t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +// TestNUDStateReachableTime verifies the correctness of the ReachableTime +// computation. +func TestNUDStateReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "AllZeros", + baseReachableTime: 0, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMaxRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMinRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 1, + want: time.Duration(defaultFakeRandomNum * float32(time.Second)), + }, + { + name: "FractionalRandomFactor", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 0.001, + maxRandomFactor: 0.002, + want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)), + }, + { + name: "MinAndMaxRandomFactorsEqual", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Second, + }, + { + name: "MinAndMaxRandomFactorsDifferent", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 2, + want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)), + }, + { + name: "MaxInt64", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Duration(math.MaxInt64), + }, + { + name: "Overflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1.5, + maxRandomFactor: 1.5, + want: time.Duration(math.MaxInt64), + }, + { + name: "DoubleOverflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 2.5, + maxRandomFactor: 2.5, + want: time.Duration(math.MaxInt64), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.NUDConfigurations{ + BaseReachableTime: test.baseReachableTime, + MinRandomFactor: test.minRandomFactor, + MaxRandomFactor: test.maxRandomFactor, + } + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} + +// TestNUDStateRecomputeReachableTime exercises the ReachableTime function +// twice to verify recomputation of reachable time when the min random factor, +// max random factor, or base reachable time changes. +func TestNUDStateRecomputeReachableTime(t *testing.T) { + const defaultBase = time.Second + const defaultMin = 2.0 * defaultMaxRandomFactor + const defaultMax = 3.0 * defaultMaxRandomFactor + + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "BaseReachableTime", + baseReachableTime: 2 * defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMax, + want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + { + name: "MinRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMax, + maxRandomFactor: defaultMax, + want: time.Duration(defaultMax * float32(defaultBase)), + }, + { + name: "MaxRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMin, + want: time.Duration(defaultMin * float32(defaultBase)), + }, + { + name: "BothRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)), + }, + { + name: "BaseReachableTimeAndBothRandomFactors", + baseReachableTime: 2 * defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = defaultBase + c.MinRandomFactor = defaultMin + c.MaxRandomFactor = defaultMax + + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + old := s.ReachableTime() + + if got, want := s.ReachableTime(), old; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Check for recomputation when changing the min random factor, the max + // random factor, the base reachability time, or any permutation of those + // three options. + c.BaseReachableTime = test.baseReachableTime + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + s.SetConfig(c) + + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Verify that ReachableTime isn't recomputed when none of the + // configuration options change. The random factor is changed so that if + // a recompution were to occur, ReachableTime would change. + rng.num = defaultFakeRandomNum / 2.0 + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 926df4d7b..5d6865e35 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,6 +14,7 @@ package stack import ( + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -24,6 +25,8 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { + _ sync.NoCopy + // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry @@ -76,13 +79,31 @@ type PacketBuffer struct { // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. NatDone bool + + // PktType indicates the SockAddrLink.PacketType of the packet as defined in + // https://www.man7.org/linux/man-pages/man7/packet.7.html. + PktType tcpip.PacketType } // Clone makes a copy of pk. It clones the Data field, which creates a new // VectorisedView but does not deep copy the underlying bytes. // // Clone also does not deep copy any of its other fields. -func (pk PacketBuffer) Clone() PacketBuffer { - pk.Data = pk.Data.Clone(nil) - return pk +// +// FIXME(b/153685824): Data gets copied but not other header references. +func (pk *PacketBuffer) Clone() *PacketBuffer { + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + Header: pk.Header, + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b331427c6..8604c4259 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/waiter" ) @@ -51,8 +52,11 @@ type TransportEndpointID struct { type ControlType int // The following are the allowed values for ControlType values. +// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlPacketTooBig ControlType = iota + ControlNetworkUnreachable ControlType = iota + ControlNoRoute + ControlPacketTooBig ControlPortUnreachable ControlUnknown ) @@ -67,12 +71,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +104,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +122,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +154,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -168,6 +172,11 @@ type TransportProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does + // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() < + // MinimumPacketSize() + Parse(pkt *PacketBuffer) (ok bool) } // TransportDispatcher contains the methods used by the network stack to deliver @@ -180,7 +189,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +198,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -240,17 +249,18 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have - // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error + // protocol. It takes ownership of pkt. pkt.TransportHeader must have already + // been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and - // protocol. pkts must not be zero length. + // protocol. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network - // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error + // header to the given destination address. It takes ownership of pkt. + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -265,7 +275,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -312,11 +322,18 @@ type NetworkProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It + // returns: + // - The encapsulated protocol, if present. + // - Whether there is an encapsulated transport protocol payload (e.g. ARP + // does not encapsulate anything). + // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader. + Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) } // NetworkDispatcher contains the methods used by the network stack to deliver -// packets to the appropriate network endpoint after it has been handled by -// the data link layer. +// inbound/outbound packets to the appropriate network/packet(if any) endpoints. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. @@ -326,7 +343,17 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) + + // DeliverOutboundPacket is called by link layer when a packet is being + // sent out. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverOutboundPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverOutboundPacket takes ownership of pkt. + DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -382,17 +409,17 @@ type LinkEndpoint interface { LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the - // given route. It sets pkt.LinkHeader if a link layer header exists. - // pkt.NetworkHeader and pkt.TransportHeader must have already been - // set. + // given route. It takes ownership of pkt. pkt.NetworkHeader and + // pkt.TransportHeader must have already been set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the - // given route. pkts must not be zero length. + // given route. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may @@ -400,7 +427,7 @@ type LinkEndpoint interface { WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. + // should already have an ethernet header. It takes ownership of vv. WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer @@ -422,6 +449,15 @@ type LinkEndpoint interface { // Wait will not block if the endpoint hasn't started any goroutines // yet, even if it might later. Wait() + + // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint. + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 + ARPHardwareType() header.ARPHardwareType + + // AddHeader adds a link layer header to pkt if required. + AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -430,7 +466,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. @@ -442,12 +478,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. - // The request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts + // the request on the local network if remoteLinkAddr is the zero value. The + // request is sent on linkEP with localAddr as the source. // // A valid response will cause the discovery protocol's network // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error + LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 150297ab9..91e0110f1 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -48,6 +48,10 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping + + // directedBroadcast indicates whether this route is sending a directed + // broadcast packet. + directedBroadcast bool } // makeRoute initializes a new route. It takes ownership of the provided @@ -113,6 +117,8 @@ func (r *Route) GSOMaxSize() uint32 { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). +// +// The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if !r.IsResolutionRequired() { // Nothing to do if there is no cache (which does the resolution on cache miss) or @@ -148,22 +154,27 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. +// +// The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } return err } @@ -175,9 +186,12 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } + // WritePackets takes ownership of pkt, calculate length first. + numPkts := pkts.Len() + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) @@ -193,17 +207,20 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } + // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Data.Size() + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } @@ -262,6 +279,12 @@ func (r *Route) Stack() *Stack { return r.ref.stack() } +// IsBroadcast returns true if the route is to send a broadcast packet. +func (r *Route) IsBroadcast() bool { + // Only IPv4 has a notion of broadcast. + return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast +} + // ReverseRoute returns new route with given source and destination address. func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0ab4c3e19..3f07e4159 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -52,7 +52,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -424,12 +424,9 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool - // tablesMu protects iptables. - tablesMu sync.RWMutex - - // tables are the iptables packet filtering and manipulation rules. The are - // protected by tablesMu.` - tables IPTables + // tables are the iptables packet filtering and manipulation rules. + // TODO(gvisor.dev/issue/170): S/R this field. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -448,6 +445,9 @@ type Stack struct { // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations + // nudConfigs is the default NUD configurations used by interfaces. + nudConfigs NUDConfigurations + // autoGenIPv6LinkLocal determines whether or not the stack will attempt // to auto-generate an IPv6 link-local address for newly enabled non-loopback // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. @@ -457,6 +457,10 @@ type Stack struct { // integrator NDP related events. ndpDisp NDPDispatcher + // nudDisp is the NUD event dispatcher that is used to send the netstack + // integrator NUD related events. + nudDisp NUDDispatcher + // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID @@ -475,6 +479,14 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. randomGenerator *mathrand.Rand + + // sendBufferSize holds the min/default/max send buffer sizes for + // endpoints other than TCP. + sendBufferSize SendBufferSizeOption + + // receiveBufferSize holds the min/default/max receive buffer sizes for + // endpoints other than TCP. + receiveBufferSize ReceiveBufferSizeOption } // UniqueID is an abstract generator of unique identifiers. @@ -513,6 +525,9 @@ type Options struct { // before assigning an address to a NIC. NDPConfigs NDPConfigurations + // NUDConfigs is the default NUD configurations used by interfaces. + NUDConfigs NUDConfigurations + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to // auto-generate an IPv6 link-local address for newly enabled non-loopback // NICs. @@ -531,6 +546,10 @@ type Options struct { // receive NDP related events. NDPDisp NDPDispatcher + // NUDDisp is the NUD event dispatcher that an integrator can provide to + // receive NUD related events. + NUDDisp NUDDispatcher + // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory @@ -665,6 +684,8 @@ func New(opts Options) *Stack { // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() + opts.NUDConfigs.resetInvalidFields() + s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -676,16 +697,29 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, + nudConfigs: opts.NUDConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, + nudDisp: opts.NUDDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, + receiveBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, } // Add specified network protocols. @@ -712,6 +746,11 @@ func New(opts Options) *Stack { return s } +// newJob returns a tcpip.Job using the Stack clock. +func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} + // UniqueID returns a unique identifier. func (s *Stack) UniqueID() uint64 { return s.uniqueIDGenerator.UniqueID() @@ -778,16 +817,17 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h } } -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (s *Stack) NowNanoseconds() int64 { - return s.clock.NowNanoseconds() +// Clock returns the Stack's clock for retrieving the current time and +// scheduling work. +func (s *Stack) Clock() tcpip.Clock { + return s.clock } // Stats returns a mutable copy of the current stats. @@ -1020,6 +1060,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() + return s.removeNICLocked(id) +} + +// removeNICLocked removes NIC and all related routes from the network stack. +// +// s.mu must be locked. +func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { nic, ok := s.nics[id] if !ok { return tcpip.ErrUnknownNICID @@ -1029,14 +1076,14 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { // Remove routes in-place. n tracks the number of routes written. n := 0 for i, r := range s.routeTable { + s.routeTable[i] = tcpip.Route{} if r.NIC != id { // Keep this route. - if i > n { - s.routeTable[n] = r - } + s.routeTable[n] = r n++ } } + s.routeTable = s.routeTable[:n] return nic.remove() @@ -1072,6 +1119,11 @@ type NICInfo struct { // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext + + // ARPHardwareType holds the ARP Hardware type of the NIC. This is the + // value sent in haType field of an ARP Request sent by this NIC and the + // value expected in the haType field of an ARP response. + ARPHardwareType header.ARPHardwareType } // HasNIC returns true if the NICID is defined in the stack. @@ -1103,6 +1155,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics @@ -1249,9 +1302,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isBroadcast := remoteAddr == header.IPv4Broadcast + isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { @@ -1272,9 +1325,16 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) - if needRoute { - r.NextHop = route.Gateway + r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) + + if len(route.Gateway) > 0 { + if needRoute { + r.NextHop = route.Gateway + } + } else if r.directedBroadcast { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } + return r, nil } } @@ -1400,25 +1460,31 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) +} + +// CheckRegisterTransportEndpoint checks if an endpoint can be registered with +// the stack transport dispatcher. +func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. -func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() s.cleanupEndpoints[ep] = struct{}{} - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CompleteTransportEndpointCleanup removes the endpoint from the cleanup @@ -1741,18 +1807,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() IPTables { - s.tablesMu.RLock() - t := s.tables - s.tablesMu.RUnlock() - return t -} - -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt IPTables) { - s.tablesMu.Lock() - s.tables = ipt - s.tablesMu.Unlock() +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent @@ -1831,10 +1887,38 @@ func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip } nic.setNDPConfigs(c) - return nil } +// NUDConfigurations gets the per-interface NUD configurations. +func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { + s.mu.RLock() + nic, ok := s.nics[id] + s.mu.RUnlock() + + if !ok { + return NUDConfigurations{}, tcpip.ErrUnknownNICID + } + + return nic.NUDConfigs() +} + +// SetNUDConfigurations sets the per-interface NUD configurations. +// +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[id] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.setNUDConfigs(c) +} + // HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement // message that it needs to handle. func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go new file mode 100644 index 000000000..0b093e6c5 --- /dev/null +++ b/pkg/tcpip/stack/stack_options.go @@ -0,0 +1,106 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "gvisor.dev/gvisor/pkg/tcpip" + +const ( + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4 KiB + + // DefaultBufferSize is the default size of the send/recv buffer for a + // transport endpoint. + DefaultBufferSize = 212 << 10 // 212 KiB + + // DefaultMaxBufferSize is the default maximum permitted size of a + // send/receive buffer. + DefaultMaxBufferSize = 4 << 20 // 4 MiB +) + +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +// SetOption allows setting stack wide options. +func (s *Stack) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SendBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.sendBufferSize = v + s.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.receiveBufferSize = v + s.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option allows retrieving stack wide options. +func (s *Stack) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SendBufferSizeOption: + s.mu.RLock() + *v = s.sendBufferSize + s.mu.RUnlock() + return nil + + case *ReceiveBufferSizeOption: + s.mu.RLock() + *v = s.receiveBufferSize + s.mu.RUnlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1a2cf007c..f22062889 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -52,6 +53,10 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -90,30 +95,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fakeNetHeaderLen) - // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) + f.dispatcher.DeliverTransportControlPacket( + tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + fakeNetNumber, + tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -132,24 +135,19 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) + pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] + pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] + pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + f.HandlePacket(r, pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -163,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -205,7 +203,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { } func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { @@ -247,6 +245,17 @@ func (*fakeNetworkProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(fakeNetHeaderLen) + return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true +} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -292,8 +301,8 @@ func TestNetworkReceive(t *testing.T) { buf := buffer.NewView(30) // Make sure packet with wrong address is not delivered. - buf[0] = 3 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 3 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -304,8 +313,8 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to first endpoint. - buf[0] = 1 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 1 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -316,8 +325,8 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to second endpoint. - buf[0] = 2 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 2 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -328,7 +337,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -340,7 +349,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -362,7 +371,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -420,7 +429,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -859,9 +868,9 @@ func TestRouteWithDownNIC(t *testing.T) { // Writes with Routes that use NIC1 after being brought up should // succeed. // - // TODO(b/147015577): Should we instead completely invalidate all - // Routes that were bound to a NIC that was brought down at some - // point? + // TODO(gvisor.dev/issue/1491): Should we instead completely + // invalidate all Routes that were bound to a NIC that was brought + // down at some point? if err := upFn(s, nicID1); err != nil { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } @@ -982,7 +991,7 @@ func TestAddressRemoval(t *testing.T) { buf := buffer.NewView(30) // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -1032,7 +1041,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) @@ -1114,7 +1123,7 @@ func TestEndpointExpiration(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte if promiscuous { if err := s.SetPromiscuousMode(nicID, true); err != nil { @@ -1277,7 +1286,7 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. @@ -1658,7 +1667,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -1766,7 +1775,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -2263,7 +2272,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2344,8 +2353,8 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) - buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -3297,7 +3306,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { // Wait for DAD to resolve. select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { @@ -3330,3 +3339,305 @@ func TestDoDADWhenNICEnabled(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) } } + +func TestStackReceiveBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + rs stack.ReceiveBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.rs); err != tc.err { + t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) + } + var rs stack.ReceiveBufferSizeOption + if tc.err == nil { + if err := s.Option(&rs); err != nil { + t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) + } + if got, want := rs, tc.rs; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestStackSendBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + ss stack.SendBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.ss); err != tc.err { + t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + } + var ss stack.SendBufferSizeOption + if tc.err == nil { + if err := s.Option(&ss); err != nil { + t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) + } + if got, want := ss, tc.ss; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID1 = 1 + ) + + defaultAddr := tcpip.AddressWithPrefix{ + Address: header.IPv4Any, + PrefixLen: 0, + } + defaultSubnet := defaultAddr.Subnet() + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedRoute stack.Route + }{ + // Broadcast to a locally attached subnet populates the broadcast MAC. + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: ipv4SubnetBcast, + RemoteLinkAddress: header.EthernetBroadcastAddress, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /31 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix31.Address, + RemoteAddress: ipv4Subnet31Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /32 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix32.Address, + RemoteAddress: ipv4Subnet32Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv6Addr.Address, + RemoteAddress: ipv6SubnetBcast, + NetProto: header.IPv6ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a remote subnet in the route table is send to the next-hop + // gateway. + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to an unknown subnet follows the default route. Note that this + // is essentially just routing an unknown destination IP, because w/o any + // subnet prefix information a subnet broadcast address is just a normal IP. + { + name: "IPv4 Broadcast to unknown subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: defaultSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + }) + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) + } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { + t.Errorf("route mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9a33ed375..b902c6ca9 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,7 +15,6 @@ package stack import ( - "container/heap" "fmt" "math/rand" @@ -23,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" ) type protocolIDs struct { @@ -43,14 +43,14 @@ type transportEndpoints struct { // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() epsByNIC, ok := eps.endpoints[id] if !ok { return } - if !epsByNIC.unregisterEndpoint(bindToDevice, ep) { + if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) { return } delete(eps.endpoints, id) @@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] @@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -204,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() @@ -214,23 +214,34 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t demux: d, netProto: netProto, transProto: transProto, - reuse: reusePort, } epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, reusePort) + return multiPortEp.singleRegisterEndpoint(t, flags) +} + +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() + + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] + if !ok { + return nil + } + + return multiPortEp.singleCheckEndpoint(flags) } // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. -func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { return false } - if multiPortEp.unregisterEndpoint(t) { + if multiPortEp.unregisterEndpoint(t, flags) { delete(epsByNIC.endpoints, bindToDevice) } return len(epsByNIC.endpoints) == 0 @@ -251,7 +262,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -279,10 +290,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) return err } } @@ -290,33 +301,15 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } -type transportEndpointHeap []TransportEndpoint - -var _ heap.Interface = (*transportEndpointHeap)(nil) - -func (h *transportEndpointHeap) Len() int { - return len(*h) -} - -func (h *transportEndpointHeap) Less(i, j int) bool { - return (*h)[i].UniqueID() < (*h)[j].UniqueID() -} - -func (h *transportEndpointHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *transportEndpointHeap) Push(x interface{}) { - *h = append(*h, x.(TransportEndpoint)) -} +// checkEndpoint checks if an endpoint can be registered with the dispatcher. +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + for _, n := range netProtos { + if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { + return err + } + } -func (h *transportEndpointHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - old[n-1] = nil - *h = old[:n-1] - return x + return nil } // multiPortEndpoint is a container for TransportEndpoints which are bound to @@ -334,9 +327,10 @@ type multiPortEndpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - endpoints transportEndpointHeap - // reuse indicates if more than one endpoint is allowed. - reuse bool + // endpoints stores the transport endpoints in the order in which they + // were bound. This is required for UDP SO_REUSEADDR. + endpoints []TransportEndpoint + flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -362,6 +356,10 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[0] } + if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent { + return mpep.endpoints[len(mpep.endpoints)-1] + } + payload := []byte{ byte(id.LocalPort), byte(id.LocalPort >> 8), @@ -379,7 +377,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs @@ -401,40 +399,63 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() + bits := flags.Bits() & ports.MultiBindFlagMask + if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if !ep.reuse || !reusePort { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { return tcpip.ErrPortInUse } } - heap.Push(&ep.endpoints, t) + ep.endpoints = append(ep.endpoints, t) + ep.flags.AddRef(bits) + + return nil +} + +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { + ep.mu.RLock() + defer ep.mu.RUnlock() + + bits := flags.Bits() & ports.MultiBindFlagMask + + if len(ep.endpoints) != 0 { + // If it was previously bound, we need to check if we can bind again. + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + return tcpip.ErrPortInUse + } + } return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. -func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { +func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() defer ep.mu.Unlock() for i, endpoint := range ep.endpoints { if endpoint == t { - heap.Remove(&ep.endpoints, i) + copy(ep.endpoints[i:], ep.endpoints[i+1:]) + ep.endpoints[len(ep.endpoints)-1] = nil + ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] + + ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask) break } } return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { - // TODO(eyalsoha): Why? - reusePort = false + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] @@ -454,15 +475,42 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.endpoints[id] = epsByNIC } - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) +} + +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return tcpip.ErrUnknownProtocol + } + + eps.mu.RLock() + defer eps.mu.RUnlock() + + epsByNIC, ok := eps.endpoints[id] + if !ok { + return nil + } + + return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep, bindToDevice) + eps.unregisterEndpoint(id, ep, flags, bindToDevice) } } } @@ -470,7 +518,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -520,7 +568,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -544,7 +592,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 2474a7db3..73dada928 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -127,7 +128,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -165,7 +166,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -195,7 +196,7 @@ func TestTransportDemuxerRegister(t *testing.T) { if !ok { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..7e8b84867 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -83,12 +84,13 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) + hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) + hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -153,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -198,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, /* reuse */ - 0, /* bindtoDevice */ + ports.Flags{}, + 0, /* bindtoDevice */ ); err != nil { return err } @@ -215,7 +217,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -232,7 +234,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -289,7 +291,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -324,6 +326,17 @@ func (*fakeTransportProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeTransportProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) + if !ok { + return false + } + pkt.TransportHeader = hdr + pkt.Data.TrimFront(fakeTransHeaderLen) + return true +} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } @@ -369,7 +382,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -380,7 +393,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -391,7 +404,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -446,7 +459,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -457,7 +470,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -468,7 +481,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -623,7 +636,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: req.ToVectorisedView(), }) @@ -642,11 +655,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.NetworkHeader[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.NetworkHeader[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } |