path: root/pkg/tcpip/stack
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--  pkg/tcpip/stack/BUILD                 |  18
-rw-r--r--  pkg/tcpip/stack/conntrack.go          | 688
-rw-r--r--  pkg/tcpip/stack/forwarder_test.go     |  11
-rw-r--r--  pkg/tcpip/stack/iptables.go           | 266
-rw-r--r--  pkg/tcpip/stack/iptables_state.go     |  40
-rw-r--r--  pkg/tcpip/stack/iptables_targets.go   |  23
-rw-r--r--  pkg/tcpip/stack/iptables_types.go     |  60
-rw-r--r--  pkg/tcpip/stack/ndp.go                | 140
-rw-r--r--  pkg/tcpip/stack/ndp_test.go           | 185
-rw-r--r--  pkg/tcpip/stack/nic.go                |  37
-rw-r--r--  pkg/tcpip/stack/nic_test.go           |  10
-rw-r--r--  pkg/tcpip/stack/packet_buffer.go      |  18
-rw-r--r--  pkg/tcpip/stack/registration.go       |  28
-rw-r--r--  pkg/tcpip/stack/stack.go              |  49
-rw-r--r--  pkg/tcpip/stack/stack_options.go      | 106
-rw-r--r--  pkg/tcpip/stack/stack_test.go         |  82
-rw-r--r--  pkg/tcpip/stack/transport_demuxer.go  |  65
17 files changed, 1228 insertions, 598 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 24f52b735..6b9a6b316 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -27,6 +27,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tuple_list",
+ out = "tuple_list.go",
+ package = "stack",
+ prefix = "tuple",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*tuple",
+ "Linker": "*tuple",
+ },
+)
+
go_library(
name = "stack",
srcs = [
@@ -35,6 +47,7 @@ go_library(
"forwarder.go",
"icmp_rate_limit.go",
"iptables.go",
+ "iptables_state.go",
"iptables_targets.go",
"iptables_types.go",
"linkaddrcache.go",
@@ -48,7 +61,9 @@ go_library(
"route.go",
"stack.go",
"stack_global_state.go",
+ "stack_options.go",
"transport_demuxer.go",
+ "tuple_list.go",
],
visibility = ["//visibility:public"],
deps = [
@@ -78,6 +93,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
+ shard_count = 20,
deps = [
":stack",
"//pkg/rand",
@@ -93,7 +109,7 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
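Note on the go_template_instance above: it instantiates //pkg/ilist:generic_list into an intrusive list type named tupleList, with *tuple serving as both the element and the linker. As a rough sketch only (the method names are taken from how the generated list is used in the conntrack.go changes below; this snippet sits inside package stack and is illustrative, not part of the diff):

    var tl tupleList
    t := &tuple{}   // *tuple embeds tupleEntry, so it links itself into the list
    tl.PushFront(t) // insert at the head of a bucket's list
    for e := tl.Front(); e != nil; e = e.Next() {
        _ = e // walk every tuple in the bucket
    }
    tl.Remove(t) // unlink in place, no allocation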
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 05bf62788..559a1c4dd 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -26,280 +26,321 @@ import (
)
// Connection tracking is used to track and manipulate packets for NAT rules.
-// The connection is created for a packet if it does not exist. Every connection
-// contains two tuples (original and reply). The tuples are manipulated if there
-// is a matching NAT rule. The packet is modified by looking at the tuples in the
-// Prerouting and Output hooks.
+// The connection is created for a packet if it does not exist. Every
+// connection contains two tuples (original and reply). The tuples are
+// manipulated if there is a matching NAT rule. The packet is modified by
+// looking at the tuples in the Prerouting and Output hooks.
+//
+// Currently, only TCP tracking is supported.
+
+// Our hash table has 16K buckets.
+// TODO(gvisor.dev/issue/170): These should be tunable.
+const numBuckets = 1 << 14
// Direction of the tuple.
-type ctDirection int
+type direction int
const (
- dirOriginal ctDirection = iota
+ dirOriginal direction = iota
dirReply
)
-// Status of connection.
-// TODO(gvisor.dev/issue/170): Add other states of connection.
-type connStatus int
-
-const (
- connNew connStatus = iota
- connEstablished
-)
-
// Manipulation type for the connection.
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
-// connTrackMutable is the manipulatable part of the tuple.
-type connTrackMutable struct {
- // addr is source address of the tuple.
- addr tcpip.Address
-
- // port is source port of the tuple.
- port uint16
-
- // protocol is network layer protocol.
- protocol tcpip.NetworkProtocolNumber
-}
-
-// connTrackImmutable is the non-manipulatable part of the tuple.
-type connTrackImmutable struct {
- // addr is destination address of the tuple.
- addr tcpip.Address
+// tuple holds a connection's identifying and manipulating data in one
+// direction. It is immutable.
+//
+// +stateify savable
+type tuple struct {
+ // tupleEntry is used to build an intrusive list of tuples.
+ tupleEntry
- // direction is direction (original or reply) of the tuple.
- direction ctDirection
+ tupleID
- // port is destination port of the tuple.
- port uint16
+ // conn is the connection tracking entry this tuple belongs to.
+ conn *conn
- // protocol is transport layer protocol.
- protocol tcpip.TransportProtocolNumber
+ // direction is the direction of the tuple.
+ direction direction
}
-// connTrackTuple represents the tuple which is created from the
-// packet.
-type connTrackTuple struct {
- // dst is non-manipulatable part of the tuple.
- dst connTrackImmutable
-
- // src is manipulatable part of the tuple.
- src connTrackMutable
+// tupleID uniquely identifies a connection in one direction. It currently
+// contains enough information to distinguish between any TCP or UDP
+// connection, and will need to be extended to support other protocols.
+//
+// +stateify savable
+type tupleID struct {
+ srcAddr tcpip.Address
+ srcPort uint16
+ dstAddr tcpip.Address
+ dstPort uint16
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
}
-// connTrackTupleHolder is the container of tuple and connection.
-type ConnTrackTupleHolder struct {
- // conn is pointer to the connection tracking entry.
- conn *connTrack
-
- // tuple is original or reply tuple.
- tuple connTrackTuple
+// reply creates the reply tupleID.
+func (ti tupleID) reply() tupleID {
+ return tupleID{
+ srcAddr: ti.dstAddr,
+ srcPort: ti.dstPort,
+ dstAddr: ti.srcAddr,
+ dstPort: ti.srcPort,
+ transProto: ti.transProto,
+ netProto: ti.netProto,
+ }
}
-// connTrack is the connection.
-type connTrack struct {
- // originalTupleHolder contains tuple in original direction.
- originalTupleHolder ConnTrackTupleHolder
-
- // replyTupleHolder contains tuple in reply direction.
- replyTupleHolder ConnTrackTupleHolder
-
- // status indicates connection is new or established.
- status connStatus
+// conn is a tracked connection.
+//
+// +stateify savable
+type conn struct {
+ // original is the tuple in original direction. It is immutable.
+ original tuple
- // timeout indicates the time connection should be active.
- timeout time.Duration
+ // reply is the tuple in reply direction. It is immutable.
+ reply tuple
- // manip indicates if the packet should be manipulated.
+ // manip indicates if the packet should be manipulated. It is immutable.
manip manipType
- // tcb is TCB control block. It is used to keep track of states
- // of tcp connection.
- tcb tcpconntrack.TCB
-
// tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb.
+ // update the state of tcb. It is immutable.
tcbHook Hook
+
+ // mu protects all mutable state.
+ mu sync.Mutex `state:"nosave"`
+ // tcb is TCB control block. It is used to keep track of states
+ // of tcp connection and is protected by mu.
+ tcb tcpconntrack.TCB
+ // lastUsed is the last time the connection saw a relevant packet, and
+ // is updated by each packet on the connection. It is protected by mu.
+ lastUsed time.Time `state:".(unixTime)"`
}
-// ConnTrackTable contains a map of all existing connections created for
-// NAT rules.
-type ConnTrackTable struct {
- // connMu protects connTrackTable.
- connMu sync.RWMutex
+// timedOut returns whether the connection timed out based on its state.
+func (cn *conn) timedOut(now time.Time) bool {
+ const establishedTimeout = 5 * 24 * time.Hour
+ const defaultTimeout = 120 * time.Second
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+ if cn.tcb.State() == tcpconntrack.ResultAlive {
+ // Use the same default as Linux, which doesn't delete
+ // established connections for 5(!) days.
+ return now.Sub(cn.lastUsed) > establishedTimeout
+ }
+ // Use the same default as Linux, which lets connections in most states
+ // other than established remain for <= 120 seconds.
+ return now.Sub(cn.lastUsed) > defaultTimeout
+}
- // connTrackTable maintains a map of tuples needed for connection tracking
- // for iptables NAT rules. The key for the map is an integer calculated
- // using seed, source address, destination address, source port and
- // destination port.
- CtMap map[uint32]ConnTrackTupleHolder
+// update the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
+}
+// ConnTrack tracks all connections created for NAT rules. Most users are
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
+//
+// ConnTrack keeps all connections in a slice of buckets, each of which holds a
+// linked list of tuples. This gives us some desirable properties:
+// - Each bucket has its own lock, lessening lock contention.
+// - The slice is large enough that lists stay short (<10 elements on average).
+// Thus traversal is fast.
+// - During linked list traversal we reap expired connections. This amortizes
+// the cost of reaping them and makes reapUnused faster.
+//
+// Locks are ordered by their location in the buckets slice. That is, a
+// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
+//
+// +stateify savable
+type ConnTrack struct {
// seed is a one-time random value initialized at stack startup
- // and is used in calculation of hash key for connection tracking
- // table.
- Seed uint32
+ // and is used in the calculation of hash keys for the list of buckets.
+ // It is immutable.
+ seed uint32
+
+ // mu protects the buckets slice, but not buckets' contents. Only take
+ // the write lock if you are modifying the slice or saving for S/R.
+ mu sync.RWMutex `state:"nosave"`
+
+ // buckets is protected by mu.
+ buckets []bucket
}
-// packetToTuple converts packet to a tuple in original direction.
-func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
- var tuple connTrackTuple
+// +stateify savable
+type bucket struct {
+ // mu protects tuples.
+ mu sync.Mutex `state:"nosave"`
+ tuples tupleList
+}
- netHeader := header.IPv4(pkt.NetworkHeader)
+// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
+// TCP header.
+func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
// TODO(gvisor.dev/issue/170): Need to support for other
// protocols as well.
+ netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tuple, tcpip.ErrUnknownProtocol
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
tcpHeader := header.TCP(pkt.TransportHeader)
if tcpHeader == nil {
- return tuple, tcpip.ErrUnknownProtocol
+ return tupleID{}, tcpip.ErrUnknownProtocol
}
- tuple.src.addr = netHeader.SourceAddress()
- tuple.src.port = tcpHeader.SourcePort()
- tuple.src.protocol = header.IPv4ProtocolNumber
-
- tuple.dst.addr = netHeader.DestinationAddress()
- tuple.dst.port = tcpHeader.DestinationPort()
- tuple.dst.protocol = netHeader.TransportProtocol()
-
- return tuple, nil
+ return tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: tcpHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: tcpHeader.DestinationPort(),
+ transProto: netHeader.TransportProtocol(),
+ netProto: header.IPv4ProtocolNumber,
+ }, nil
}
-// getReplyTuple creates reply tuple for the given tuple.
-func getReplyTuple(tuple connTrackTuple) connTrackTuple {
- var replyTuple connTrackTuple
- replyTuple.src.addr = tuple.dst.addr
- replyTuple.src.port = tuple.dst.port
- replyTuple.src.protocol = tuple.src.protocol
- replyTuple.dst.addr = tuple.src.addr
- replyTuple.dst.port = tuple.src.port
- replyTuple.dst.protocol = tuple.dst.protocol
- replyTuple.dst.direction = dirReply
-
- return replyTuple
+// newConn creates new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
}
-// makeNewConn creates new connection.
-func makeNewConn(tuple, replyTuple connTrackTuple) connTrack {
- var conn connTrack
- conn.status = connNew
- conn.originalTupleHolder.tuple = tuple
- conn.originalTupleHolder.conn = &conn
- conn.replyTupleHolder.tuple = replyTuple
- conn.replyTupleHolder.conn = &conn
+// connFor gets the conn for pkt if it exists, or returns nil
+// if it does not. It returns an error when pkt does not contain a valid TCP
+// header.
+// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
+// other transport protocols.
+func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil, dirOriginal
+ }
- return conn
-}
+ bucket := ct.bucket(tid)
+ now := time.Now()
+
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ ct.buckets[bucket].mu.Lock()
+ defer ct.buckets[bucket].mu.Unlock()
+
+ // Iterate over the tuples in a bucket, cleaning up any unused
+ // connections we find.
+ for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
+ // Clean up any timed-out connections we happen to find.
+ if ct.reapTupleLocked(other, bucket, now) {
+ // The tuple expired.
+ continue
+ }
+ if tid == other.tupleID {
+ return other.conn, other.direction
+ }
+ }
-// getTupleHash returns hash of the tuple. The fields used for
-// generating hash are seed (generated once for stack), source address,
-// destination address, source port and destination ports.
-func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
- h := jenkins.Sum32(ct.Seed)
- h.Write([]byte(tuple.src.addr))
- h.Write([]byte(tuple.dst.addr))
- portBuf := make([]byte, 2)
- binary.LittleEndian.PutUint16(portBuf, tuple.src.port)
- h.Write([]byte(portBuf))
- binary.LittleEndian.PutUint16(portBuf, tuple.dst.port)
- h.Write([]byte(portBuf))
-
- return h.Sum32()
+ return nil, dirOriginal
}
-// connTrackForPacket returns connTrack for packet.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
-// transport protocols.
-func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
- var dir ctDirection
- tuple, err := packetToTuple(pkt, hook)
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+ tid, err := packetToTupleID(pkt)
if err != nil {
- return nil, dir
- }
-
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
-
- connTrackTable := ct.CtMap
- hash := ct.getTupleHash(tuple)
-
- var conn *connTrack
- switch createConn {
- case true:
- // If connection does not exist for the hash, create a new
- // connection.
- replyTuple := getReplyTuple(tuple)
- replyHash := ct.getTupleHash(replyTuple)
- newConn := makeNewConn(tuple, replyTuple)
- conn = &newConn
-
- // Add tupleHolders to the map.
- // TODO(gvisor.dev/issue/170): Need to support collisions using linked list.
- ct.CtMap[hash] = conn.originalTupleHolder
- ct.CtMap[replyHash] = conn.replyTupleHolder
- default:
- tupleHolder, ok := connTrackTable[hash]
- if !ok {
- return nil, dir
- }
-
- // If this is the reply of new connection, set the connection
- // status as ESTABLISHED.
- conn = tupleHolder.conn
- if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply {
- conn.status = connEstablished
- }
- if tupleHolder.conn == nil {
- panic("tupleHolder has null connection tracking entry")
- }
+ return nil
+ }
+ if hook != Prerouting && hook != Output {
+ return nil
+ }
- dir = tupleHolder.tuple.dst.direction
+ // Create a new connection and change the port as per the iptables
+ // rule. This tuple will be used to manipulate the packet in
+ // handlePacket.
+ replyTID := tid.reply()
+ replyTID.srcAddr = rt.MinIP
+ replyTID.srcPort = rt.MinPort
+ var manip manipType
+ switch hook {
+ case Prerouting:
+ manip = manipDstPrerouting
+ case Output:
+ manip = manipDstOutput
}
- return conn, dir
+ conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
}
-// SetNatInfo will manipulate the tuples according to iptables NAT rules.
-func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) {
- // Get the connection. Connection is always created before this
- // function is called.
- conn, _ := ct.connTrackForPacket(pkt, hook, false)
- if conn == nil {
- panic("connection should be created to manipulate tuples.")
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
+ // Lock the buckets in the correct order.
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ if tupleBucket < replyBucket {
+ ct.buckets[tupleBucket].mu.Lock()
+ ct.buckets[replyBucket].mu.Lock()
+ } else if tupleBucket > replyBucket {
+ ct.buckets[replyBucket].mu.Lock()
+ ct.buckets[tupleBucket].mu.Lock()
+ } else {
+ // Both tuples are in the same bucket.
+ ct.buckets[tupleBucket].mu.Lock()
}
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
- // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to
- // support changing of address for Prerouting.
-
- // Change the port as per the iptables rule. This tuple will be used
- // to manipulate the packet in HandlePacket.
- conn.replyTupleHolder.tuple.src.addr = rt.MinIP
- conn.replyTupleHolder.tuple.src.port = rt.MinPort
- newHash := ct.getTupleHash(conn.replyTupleHolder.tuple)
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
- // Add the changed tuple to the map.
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
- ct.CtMap[newHash] = conn.replyTupleHolder
- if hook == Output {
- conn.replyTupleHolder.conn.manip = manipDstOutput
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
}
- // Delete the old tuple.
- delete(ct.CtMap, replyHash)
+ // Unlocking can happen in any order.
+ ct.buckets[tupleBucket].mu.Unlock()
+ if tupleBucket != replyBucket {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook..
-func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) {
+// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
+func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -308,21 +349,31 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection)
// modified.
switch dir {
case dirOriginal:
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
case dirReply:
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) {
+func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -331,13 +382,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
// modified. For prerouting redirection, we only reach this point
// when replying, so packet sources are modified.
if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.replyTupleHolder.tuple.src.port
+ port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
} else {
- port := conn.originalTupleHolder.tuple.dst.port
+ port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
}
// Calculate the TCP checksum and set it.
@@ -356,33 +407,32 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route,
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
-// HandlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// handlePacket will manipulate the port and address of the packet if the
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
}
- conn, dir := ct.connTrackForPacket(pkt, hook, false)
- // Connection or Rule not found for the packet.
- if conn == nil {
- return
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
- netHeader := header.IPv4(pkt.NetworkHeader)
- // TODO(gvisor.dev/issue/170): Need to support for other transport
- // protocols as well.
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return
+ conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
+ if conn == nil {
+ return true
}
tcpHeader := header.TCP(pkt.TransportHeader)
if tcpHeader == nil {
- return
+ return false
}
switch hook {
@@ -396,39 +446,161 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
// other tcp states.
- var st tcpconntrack.Result
- if conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else {
- switch hook {
- case conn.tcbHook:
- st = conn.tcb.UpdateStateOutbound(tcpHeader)
- default:
- st = conn.tcb.UpdateStateInbound(tcpHeader)
- }
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ conn.lastUsed = time.Now()
+ // Update connection state.
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
}
- // Delete conntrack if tcp connection is closed.
- if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
- ct.deleteConnTrack(conn)
+ // We only track TCP connections.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return
}
-}
-// deleteConnTrack deletes the connection.
-func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) {
- if conn == nil {
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
return
}
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ ct.insertConn(conn)
+}
- tuple := conn.originalTupleHolder.tuple
- hash := ct.getTupleHash(tuple)
- replyTuple := conn.replyTupleHolder.tuple
- replyHash := ct.getTupleHash(replyTuple)
+// bucket gets the conntrack bucket for a tupleID.
+func (ct *ConnTrack) bucket(id tupleID) int {
+ h := jenkins.Sum32(ct.seed)
+ h.Write([]byte(id.srcAddr))
+ h.Write([]byte(id.dstAddr))
+ shortBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
+ h.Write([]byte(shortBuf))
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ return int(h.Sum32()) % len(ct.buckets)
+}
- ct.connMu.Lock()
- defer ct.connMu.Unlock()
+// reapUnused deletes timed out entries from the conntrack map. The rules for
+// reaping are:
+// - Most reaping occurs in connFor, which is called on each packet. connFor
+// cleans up the bucket the packet's connection maps to. Thus calls to
+// reapUnused should be fast.
+// - Each call to reapUnused traverses a fraction of the conntrack table.
+// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
+// - After reaping, reapUnused decides when it should next run based on the
+// ratio of expired connections to examined connections. If the ratio is
+// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
+// slightly increases the interval between runs.
+// - maxFullTraversal caps the time it takes to traverse the entire table.
+//
+// reapUnused returns the next bucket that should be checked and the time after
+// which it should be called again.
+func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
+ // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
+ // it is in Linux via sysctl.
+ const fractionPerReaping = 128
+ const maxExpiredPct = 50
+ const maxFullTraversal = 60 * time.Second
+ const minInterval = 10 * time.Millisecond
+ const maxInterval = maxFullTraversal / fractionPerReaping
+
+ now := time.Now()
+ checked := 0
+ expired := 0
+ var idx int
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
+ idx = (i + start) % len(ct.buckets)
+ ct.buckets[idx].mu.Lock()
+ for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ checked++
+ if ct.reapTupleLocked(tuple, idx, now) {
+ expired++
+ }
+ }
+ ct.buckets[idx].mu.Unlock()
+ }
+ // We already checked buckets[idx].
+ idx++
+
+ // If half or more of the connections are expired, the table has gotten
+ // stale. Reschedule quickly.
+ expiredPct := 0
+ if checked != 0 {
+ expiredPct = expired * 100 / checked
+ }
+ if expiredPct > maxExpiredPct {
+ return idx, minInterval
+ }
+ if interval := prevInterval + minInterval; interval <= maxInterval {
+ // Increment the interval between runs.
+ return idx, interval
+ }
+ // We've hit the maximum interval.
+ return idx, maxInterval
+}
+
+// reapTupleLocked tries to remove tuple and its reply from the table. It
+// returns whether the tuple's connection has timed out.
+//
+// Preconditions: ct.mu is locked for reading and bucket is locked.
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+ if !tuple.conn.timedOut(now) {
+ return false
+ }
+
+ // To maintain lock order, we can only reap these tuples if the reply
+ // appears later in the table.
+ replyBucket := ct.bucket(tuple.reply())
+ if bucket > replyBucket {
+ return true
+ }
+
+ // Don't re-lock if both tuples are in the same bucket.
+ differentBuckets := bucket != replyBucket
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Lock()
+ }
+
+ // We have the buckets locked and can remove both tuples.
+ if tuple.direction == dirOriginal {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ } else {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
+ }
+ ct.buckets[bucket].tuples.Remove(tuple)
+
+ // Don't re-unlock if both tuples are in the same bucket.
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
- delete(ct.CtMap, hash)
- delete(ct.CtMap, replyHash)
+ return true
}
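To make the bucket scheme above concrete: a connection's original tupleID and its reply() are hashed (together with the per-stack seed) into buckets, and both tuples point back at the same conn, which is why connFor can find the entry from either direction. The standalone sketch below illustrates that idea only; hash/fnv stands in for gVisor's jenkins package and the struct is trimmed down, so it is not the netstack implementation.

    package main

    import (
        "encoding/binary"
        "fmt"
        "hash/fnv"
    )

    // id mirrors the shape of tupleID with only the hashed fields.
    type id struct {
        srcAddr, dstAddr string
        srcPort, dstPort uint16
    }

    // reply flips source and destination, like tupleID.reply.
    func (i id) reply() id {
        return id{srcAddr: i.dstAddr, srcPort: i.dstPort, dstAddr: i.srcAddr, dstPort: i.srcPort}
    }

    // bucket hashes the seed and tuple fields, in the spirit of ConnTrack.bucket.
    func bucket(seed uint32, i id, numBuckets int) int {
        h := fnv.New32a() // stand-in for jenkins.Sum32(seed)
        var buf [4]byte
        binary.LittleEndian.PutUint32(buf[:], seed)
        h.Write(buf[:])
        h.Write([]byte(i.srcAddr))
        h.Write([]byte(i.dstAddr))
        binary.LittleEndian.PutUint16(buf[:2], i.srcPort)
        h.Write(buf[:2])
        binary.LittleEndian.PutUint16(buf[:2], i.dstPort)
        h.Write(buf[:2])
        return int(h.Sum32()) % numBuckets
    }

    func main() {
        orig := id{srcAddr: "10.0.0.1", srcPort: 49152, dstAddr: "10.0.0.2", dstPort: 80}
        // The original and reply tuples usually land in different buckets,
        // which is why insertConn and reapTupleLocked order their locks.
        fmt.Println(bucket(42, orig, 1<<14), bucket(42, orig.reply(), 1<<14))
    }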
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index a6546cef0..bca1d940b 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -301,6 +302,16 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 4e9b404c8..cbbae4224 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -16,39 +16,49 @@ package stack
import (
"fmt"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// Table names.
+// tableID is an index into IPTables.tables.
+type tableID int
+
const (
- TablenameNat = "nat"
- TablenameMangle = "mangle"
- TablenameFilter = "filter"
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
)
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+// Table names.
const (
- ChainNamePrerouting = "PREROUTING"
- ChainNameInput = "INPUT"
- ChainNameForward = "FORWARD"
- ChainNameOutput = "OUTPUT"
- ChainNamePostrouting = "POSTROUTING"
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
)
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
+
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
+// reaperDelay is how long to wait before starting to reap connections.
+const reaperDelay = 5 * time.Second
+
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables() *IPTables {
- // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
- // iotas.
return &IPTables{
- tables: map[string]Table{
- TablenameNat: Table{
+ tables: [numTables]Table{
+ natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
@@ -56,65 +66,71 @@ func DefaultTables() *IPTables {
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- UserChains: map[string]int{},
},
- TablenameMangle: Table{
+ mangleID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- Underflows: map[Hook]int{
- Prerouting: 0,
- Output: 1,
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
- TablenameFilter: Table{
+ filterID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
},
- priorities: map[Hook][]string{
- Input: []string{TablenameNat, TablenameFilter},
- Prerouting: []string{TablenameMangle, TablenameNat},
- Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
},
- connections: ConnTrackTable{
- CtMap: make(map[uint32]ConnTrackTupleHolder),
- Seed: generateRandUint32(),
+ connections: ConnTrack{
+ seed: generateRandUint32(),
},
+ reaperDone: make(chan struct{}, 1),
}
}
@@ -123,69 +139,59 @@ func DefaultTables() *IPTables {
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// EmptyNatTable returns a Table with no rules and the filter table chains
+// EmptyNATTable returns a Table with no rules and the filter table chains
// mapped to HookUnset.
-func EmptyNatTable() Table {
+func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
},
- Underflows: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// GetTable returns table by name.
+// GetTable returns a table by name.
func (it *IPTables) GetTable(name string) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
it.mu.RLock()
defer it.mu.RUnlock()
- t, ok := it.tables[name]
- return t, ok
+ return it.tables[id], true
}
// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table) {
- it.mu.Lock()
- defer it.mu.Unlock()
- it.tables[name] = table
-}
-
-// ModifyTables acquires write-lock and calls fn with internal name-to-table
-// map. This function can be used to update multiple tables atomically.
-func (it *IPTables) ModifyTables(fn func(map[string]Table)) {
+func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
it.mu.Lock()
defer it.mu.Unlock()
- fn(it.tables)
-}
-
-// GetPriorities returns slice of priorities associated with hook.
-func (it *IPTables) GetPriorities(hook Hook) []string {
- it.mu.RLock()
- defer it.mu.RUnlock()
- return it.priorities[hook]
+ // If iptables is being enabled, initialize the conntrack table and
+ // reaper.
+ if !it.modified {
+ it.connections.buckets = make([]bucket, numBuckets)
+ it.startReaper(reaperDelay)
+ }
+ it.modified = true
+ it.tables[id] = table
+ return nil
}
// A chainVerdict is what a table decides should be done with a packet.
@@ -209,13 +215,30 @@ const (
//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+ // Many users never configure iptables. Spare them the cost of rule
+ // traversal if rules have never been set.
+ it.mu.RLock()
+ if !it.modified {
+ it.mu.RUnlock()
+ return true
+ }
+ it.mu.RUnlock()
+
// Packets are manipulated only if connection and matching
// NAT rule exists.
- it.connections.HandlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.GetPriorities(hook) {
- table, _ := it.GetTable(tablename)
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ table := it.tables[tableID]
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
@@ -244,17 +267,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
}
}
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+	// needed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
// Every table returned Accept.
return true
}
+// beforeSave is invoked by stateify.
+func (it *IPTables) beforeSave() {
+ // Ensure the reaper exits cleanly.
+ it.reaperDone <- struct{}{}
+ // Prevent others from modifying the connection table.
+ it.connections.mu.Lock()
+}
+
+// afterLoad is invoked by stateify.
+func (it *IPTables) afterLoad() {
+ it.startReaper(reaperDelay)
+}
+
+// startReaper starts a goroutine that wakes up periodically to reap timed out
+// connections.
+func (it *IPTables) startReaper(interval time.Duration) {
+ go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved.
+ bucket := 0
+ for {
+ select {
+ case <-it.reaperDone:
+ return
+ case <-time.After(interval):
+ bucket, interval = it.connections.reapUnused(bucket, interval)
+ }
+ }
+ }()
+}
+
// CheckPackets runs pkts through the rules for hook and returns a map of packets that
// should not go forward.
//
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-//
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
@@ -278,9 +343,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
return drop, natPkts
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
@@ -325,23 +390,12 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
return chainDrop
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
- // If pkt.NetworkHeader hasn't been set yet, it will be contained in
- // pkt.Data.
- if pkt.NetworkHeader == nil {
- var ok bool
- pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // Precondition has been violated.
- panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
- }
- }
-
// Check whether the packet matches the IP header filter.
if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) {
// Continue on to the next rule.
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
new file mode 100644
index 000000000..529e02a07
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastUsed is invoked by stateify.
+func (cn *conn) saveLastUsed() unixTime {
+ return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
+}
+
+// loadLastUsed is invoked by stateify.
+func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.lastUsed = time.Unix(unix.second, unix.nano)
+}
+
+// beforeSave is invoked by stateify.
+func (ct *ConnTrack) beforeSave() {
+ ct.mu.Lock()
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 92e31643e..dc88033c7 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -24,7 +24,7 @@ import (
type AcceptTarget struct{}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -32,7 +32,7 @@ func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, t
type DropTarget struct{}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -41,7 +41,7 @@ func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcp
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -61,7 +61,7 @@ func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route
type ReturnTarget struct{}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -92,7 +92,7 @@ type RedirectTarget struct {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -150,12 +150,11 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook
return RuleAccept, 0
}
- // Set up conection for matching NAT rule.
- // Only the first packet of the connection comes here.
- // Other packets will be manipulated in connection tracking.
- if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil {
- ct.SetNatInfo(pkt, rt, hook)
- ct.HandlePacket(pkt, hook, gso, r)
+		// Set up connection for matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ ct.handlePacket(pkt, hook, gso, r)
}
default:
return RuleDrop, 0
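Every Target.Action implementation now receives a *ConnTrack rather than the old *ConnTrackTable. For illustration only, a hypothetical target (logDropTarget is not part of this change) written against the updated signature inside package stack:

    // logDropTarget drops every packet after logging it. It exists purely to
    // show the new Action signature; the conn-tracking argument is unused.
    type logDropTarget struct{}

    // Action implements Target.Action.
    func (logDropTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
        log.Debugf("logDropTarget: dropping packet at hook %d", hook)
        return RuleDrop, 0
    }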
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 4a6a5c6f1..73274ada9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -78,67 +78,65 @@ const (
)
// IPTables holds all the tables for a netstack.
+//
+// +stateify savable
type IPTables struct {
- // mu protects tables and priorities.
+ // mu protects tables, priorities, and modified.
mu sync.RWMutex
- // tables maps table names to tables. User tables have arbitrary names. mu
- // needs to be locked for accessing.
- tables map[string]Table
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be locked for accessing.
+ tables [numTables]Table
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. mu needs to be locked for accessing.
- priorities map[Hook][]string
+ priorities [NumHooks][]tableID
+
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ modified bool
- connections ConnTrackTable
+ connections ConnTrack
+
+ // reaperDone can be signalled to stop the reaper goroutine.
+ reaperDone chan struct{}
}
// A Table defines a set of chains and hooks into the network stack. It is
-// really just a list of rules with some metadata for entrypoints and such.
+// really just a list of rules.
+//
+// +stateify savable
type Table struct {
// Rules holds the rules that make up the table.
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]int
-
- // Metadata holds information about the Table that is useful to users
- // of IPTables, but not to the netstack IPTables code itself.
- metadata interface{}
+ Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook := range table.BuiltinChains {
- hooks |= 1 << hook
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
}
return hooks
}
-// Metadata returns the metadata object stored in table.
-func (table *Table) Metadata() interface{} {
- return table.metadata
-}
-
-// SetMetadata sets the metadata object stored in table.
-func (table *Table) SetMetadata(metadata interface{}) {
- table.metadata = metadata
-}
-
// A Rule is a packet processing rule. It consists of two pieces. First it
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
// applies to any packet.
+//
+// +stateify savable
type Rule struct {
// Filter holds basic IP filtering fields common to every rule.
Filter IPHeaderFilter
@@ -151,6 +149,8 @@ type Rule struct {
}
// IPHeaderFilter holds basic IP filtering data common to every rule.
+//
+// +stateify savable
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
@@ -258,5 +258,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
+ Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
}
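Because BuiltinChains is now a fixed-size array rather than a map, ValidHooks must skip HookUnset entries explicitly. A standalone sketch of the bitmap it produces for the filter table's layout in DefaultTables (Input=0, Forward=1, Output=2, the rest unset); constant values mirror the stack package's hook order:

    package main

    import "fmt"

    // Hooks in the same order as the stack package.
    const (
        Prerouting = iota
        Input
        Forward
        Output
        Postrouting
        NumHooks
    )

    // HookUnset marks a hook with no entrypoint, as in iptables.go.
    const HookUnset = -1

    func main() {
        chains := [NumHooks]int{Prerouting: HookUnset, Input: 0, Forward: 1, Output: 2, Postrouting: HookUnset}
        hooks := uint32(0)
        for hook, ruleIdx := range chains {
            if ruleIdx != HookUnset {
                hooks |= 1 << hook
            }
        }
        fmt.Printf("%05b\n", hooks) // 01110: Input, Forward and Output bits set
    }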
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index e28c23d66..9dce11a97 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -469,7 +469,7 @@ type ndpState struct {
rtrSolicit struct {
// The timer used to send the next router solicitation message.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the Router Solicitation timer know that it has been stopped.
//
@@ -503,7 +503,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
@@ -515,38 +515,38 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- // Timer to invalidate the default router.
+ // Job to invalidate the default router.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- // Timer to invalidate the on-link prefix.
+ // Job to invalidate the on-link prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// tempSLAACAddrState holds state associated with a temporary SLAAC address.
type tempSLAACAddrState struct {
- // Timer to deprecate the temporary SLAAC address.
+ // Job to deprecate the temporary SLAAC address.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the temporary SLAAC address.
+ // Job to invalidate the temporary SLAAC address.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
- // Timer to regenerate the temporary SLAAC address.
+ // Job to regenerate the temporary SLAAC address.
//
// Must not be nil.
- regenTimer *tcpip.CancellableTimer
+ regenJob *tcpip.Job
createdAt time.Time
@@ -561,15 +561,15 @@ type tempSLAACAddrState struct {
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
- // Timer to deprecate the prefix.
+ // Job to deprecate the prefix.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the prefix.
+ // Job to invalidate the prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
validUntil time.Time
@@ -651,12 +651,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
var done bool
- var timer *time.Timer
+ var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
// cannot be done while holding the NIC's lock. This is effectively the same
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
- timer = time.AfterFunc(0, func() {
+ timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
ndp.nic.mu.Lock()
defer ndp.nic.mu.Unlock()
@@ -871,9 +871,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
- // the invalidation timer.
- rtr.invalidationTimer.StopLocked()
- rtr.invalidationTimer.Reset(rl)
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
@@ -950,7 +950,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.StopLocked()
+ rtr.invalidationJob.Cancel()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
@@ -979,12 +979,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
state := defaultRouterState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
- state.invalidationTimer.Reset(rl)
+ state.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = state
}
@@ -1009,13 +1009,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
state := onLinkPrefixState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
if l < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(l)
+ state.invalidationJob.Schedule(l)
}
ndp.onLinkPrefixes[prefix] = state
@@ -1033,7 +1033,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- s.invalidationTimer.StopLocked()
+ s.invalidationJob.Cancel()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
@@ -1082,14 +1082,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
//
- // Update the invalidation timer.
+ // Update the invalidation job.
- prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationJob.Cancel()
if vl < header.NDPInfiniteLifetime {
- // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
// the new valid lifetime.
- prefixState.invalidationTimer.Reset(vl)
+ prefixState.invalidationJob.Schedule(vl)
}
ndp.onLinkPrefixes[prefix] = prefixState
@@ -1154,7 +1154,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1162,7 +1162,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1184,19 +1184,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if !ndp.generateSLAACAddr(prefix, &state) {
// We were unable to generate an address for the prefix, so we do nothing
- // further as there is no reason to maintain state or timers for a prefix we
+ // further as there is no reason to maintain state or jobs for a prefix we
// do not have an address for.
return
}
- // Setup the initial timers to deprecate and invalidate prefix.
+ // Set up the initial jobs to deprecate and invalidate the prefix.
if pl < header.NDPInfiniteLifetime && pl != 0 {
- state.deprecationTimer.Reset(pl)
+ state.deprecationJob.Schedule(pl)
}
if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
+ state.invalidationJob.Schedule(vl)
state.validUntil = now.Add(vl)
}
@@ -1428,7 +1428,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1441,7 +1441,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1454,7 +1454,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1481,9 +1481,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ref: ref,
}
- state.deprecationTimer.Reset(pl)
- state.invalidationTimer.Reset(vl)
- state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
prefixState.generationAttempts++
prefixState.tempAddrs[generatedAddr.Address] = state
@@ -1518,16 +1518,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
prefixState.stableAddr.ref.deprecated = false
}
- // If prefix was preferred for some finite lifetime before, stop the
- // deprecation timer so it can be reset.
- prefixState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
now := time.Now()
- // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
- prefixState.deprecationTimer.Reset(pl)
+ prefixState.deprecationJob.Schedule(pl)
}
prefixState.preferredUntil = now.Add(pl)
} else {
@@ -1546,9 +1546,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
if vl >= header.NDPInfiniteLifetime {
- // Handle the infinite valid lifetime separately as we do not keep a timer
- // in this case.
- prefixState.invalidationTimer.StopLocked()
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
prefixState.validUntil = time.Time{}
} else {
var effectiveVl time.Duration
@@ -1569,8 +1569,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
if effectiveVl != 0 {
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
prefixState.validUntil = now.Add(effectiveVl)
}
}
@@ -1582,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// Note, we do not need to update the entries in the temporary address map
- // after updating the timers because the timers are held as pointers.
+ // after updating the jobs because the jobs are held as pointers.
var regenForAddr tcpip.Address
allAddressesRegenerated := true
for tempAddr, tempAddrState := range prefixState.tempAddrs {
@@ -1596,14 +1596,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
- // reset the invalidation timer.
+ // reset the invalidation job.
newValidLifetime := validUntil.Sub(now)
if newValidLifetime <= 0 {
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
continue
}
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.invalidationTimer.Reset(newValidLifetime)
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
// As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
// address is the lower of the preferred lifetime of the stable address or
@@ -1616,17 +1616,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer preferred, deprecate it immediately.
- // Otherwise, reset the deprecation timer.
+ // Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
- tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.ref)
} else {
tempAddrState.ref.deprecated = false
- tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.regenJob.Cancel()
if tempAddrState.regenerated {
} else {
allAddressesRegenerated = false
@@ -1637,7 +1637,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// immediately after we finish iterating over the temporary addresses.
regenForAddr = tempAddr
} else {
- tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
}
}
}
@@ -1717,7 +1717,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
//
// Panics if the SLAAC prefix is not known.
//
@@ -1729,8 +1729,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
}
state.stableAddr.ref = nil
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
}
@@ -1775,13 +1775,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
}
// cleanupTempSLAACAddrResources cleans up a temporary SLAAC address's
-// timers and entry.
+// jobs and entry.
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
- tempAddrState.deprecationTimer.StopLocked()
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
@@ -1860,7 +1860,7 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
+ ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
ndp.nic.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index ae326b3ab..644ba7c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -36,15 +36,24 @@ import (
)
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
- defaultTimeout = 100 * time.Millisecond
- defaultAsyncEventTimeout = time.Second
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 10 * time.Second
+
+ // Extra time to use when waiting for an async event to not occur.
+ //
+ // Since a negative check is used to make sure an event did not happen, it is
+ // okay to use a smaller timeout than in the positive case since an execution
+ // stall relative to the monotonic clock will not affect the expected
+ // outcome.
+ defaultAsyncNegativeEventTimeout = time.Second
)
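To make the intent of the two constants concrete, here is an illustrative helper written as if it lived in this test file (the event channel and messages are hypothetical, not part of this change): positive expectations wait with the generous timeout, negative expectations with the short one.

func exampleAsyncEventChecks(t *testing.T, eventC <-chan struct{}) {
	t.Helper()

	// Positive check: the event must arrive, so wait generously to avoid
	// flakes on slow machines.
	select {
	case <-eventC:
	case <-time.After(defaultAsyncPositiveEventTimeout):
		t.Fatal("timed out waiting for event")
	}

	// Negative check: no event should arrive; a short wait suffices because
	// an execution stall cannot turn a true negative into a failure.
	select {
	case <-eventC:
		t.Fatal("unexpectedly received an event")
	case <-time.After(defaultAsyncNegativeEventTimeout):
	}
}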
var (
@@ -442,7 +451,7 @@ func TestDADResolve(t *testing.T) {
// Make sure the address does not resolve before the resolution time has
// passed.
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout)
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
} else if want := (tcpip.AddressWithPrefix{}); addr != want {
@@ -471,7 +480,7 @@ func TestDADResolve(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(2 * defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
@@ -1169,7 +1178,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
select {
case <-ndpDisp.routerC:
t.Fatal("should not have received any router events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -1245,14 +1254,14 @@ func TestRouterDiscovery(t *testing.T) {
default:
}
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1262,14 +1271,14 @@ func TestRouterDiscovery(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
expectRouterEvent(llAddr2, false)
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1418,7 +1427,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("should not have received any prefix events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -1493,14 +1502,14 @@ func TestPrefixDiscovery(t *testing.T) {
default:
}
- // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1565,7 +1574,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with finite lifetime.
@@ -1590,7 +1599,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with a prefix with a lifetime value greater than the
@@ -1599,7 +1608,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout):
+ case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with 0 lifetime.
@@ -1835,7 +1844,7 @@ func TestAutoGenAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -1962,7 +1971,7 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -1975,7 +1984,7 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -2081,10 +2090,10 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
@@ -2180,7 +2189,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
@@ -2188,7 +2197,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Errorf("got unxpected auto gen addr event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -2265,7 +2274,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
select {
@@ -2273,7 +2282,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2363,13 +2372,13 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
t.Fatal(mismatch)
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -2386,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
for _, addr := range tempAddrs {
// Wait for a deprecation then invalidation event, or just an invalidation
// event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation timers could fire in any
+ // cases because the deprecation and invalidation jobs could execute in any
// order.
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2398,7 +2407,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
@@ -2407,12 +2416,12 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
} else {
t.Fatalf("got unexpected auto-generated event = %+v", e)
}
- case <-time.After(invalidateAfter + defaultAsyncEventTimeout):
+ case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
@@ -2423,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
}
-// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
-// regeneration timer gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
nicID = 1
regenAfter = 2 * time.Second
@@ -2517,14 +2526,14 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
}
// Prefer the prefix again.
//
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a timer.
+ // - this regeneration does not depend on a job.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(tempAddr2, newAddr)
@@ -2546,24 +2555,24 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
}
// Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration timer gets reset.
+ // RA, the regeneration job gets scheduled again.
//
// The maximum lifetime is the sum of the minimum lifetimes for temporary
// addresses + the time that has already passed since the last address was
- // generated so that the regeneration timer is needed to generate the next
+ // generated so that the regeneration job is needed to generate the next
// address.
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout
+ newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
}
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
@@ -2711,7 +2720,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2724,7 +2733,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -2984,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr(addr2)
}
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
@@ -3070,7 +3079,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3110,7 +3119,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3124,7 +3133,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3156,7 +3165,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
@@ -3165,12 +3174,12 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
} else {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3295,7 +3304,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -3439,7 +3448,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
}
// Wait for the invalidation event.
@@ -3448,7 +3457,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(2 * defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -3504,12 +3513,12 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
expectAutoGenAddrEvent(addr, invalidatedAddr)
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -3672,7 +3681,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -3770,7 +3779,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3837,7 +3846,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -3863,7 +3872,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -4030,7 +4039,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -4149,7 +4158,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -4251,7 +4260,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
@@ -4277,7 +4286,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
}
} else {
@@ -4285,7 +4294,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
}
- case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for auto gen addr event")
}
}
@@ -4869,7 +4878,7 @@ func TestCleanupNDPState(t *testing.T) {
// Should not get any more events (invalidation timers should have been
// cancelled when the NDP state was cleaned up).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout)
select {
case <-ndpDisp.routerC:
t.Error("unexpected router event")
@@ -5172,24 +5181,24 @@ func TestRouterSolicitation(t *testing.T) {
// Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
remaining--
}
for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout)
- waitForPkt(2 * defaultAsyncEventTimeout)
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
+ waitForPkt(defaultAsyncPositiveEventTimeout)
} else {
- waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
}
}
// Make sure no more RS.
if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout)
+ waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
}
// Make sure the counter got properly
@@ -5305,11 +5314,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stop soliciting routers.
test.stopFn(t, s, true /* first */)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
// A single RS may have been sent before solicitations were stopped.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok = e.ReadContext(ctx); ok {
t.Fatal("should not have sent more than one RS message")
@@ -5319,7 +5328,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
test.stopFn(t, s, false /* first */)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
@@ -5332,10 +5341,10 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Start soliciting routers.
test.startFn(t, s)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
+ waitForPkt(delay + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
@@ -5344,7 +5353,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Starting router solicitations after it has already completed should do
// nothing.
test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after finishing router solicitations")
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index afb7dfeaf..fea0ce7e8 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// Are any packet sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+ // Add any other packet sockets that may be listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
@@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol-specific packet endpoints because, as on
+ // Linux, only ETH_P_ALL endpoints get outbound packets.
+ // Deliver to any packet sockets that may be listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
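As a hedged sketch of where DeliverOutboundPacket is expected to be invoked from (the wrapper type below is hypothetical, and the WritePacket signature and Route link-address fields are assumed from this era of the API; the real callers are the link endpoints updated elsewhere in this series), a link endpoint hands each outgoing packet to its dispatcher before writing it to the lower endpoint:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// outboundSniffer wraps a lower LinkEndpoint and offers outgoing packets to
// packet endpoints via the NetworkDispatcher before transmitting them.
type outboundSniffer struct {
	stack.LinkEndpoint                         // the wrapped (lower) endpoint
	dispatcher         stack.NetworkDispatcher // captured in Attach
}

func (e *outboundSniffer) Attach(dispatcher stack.NetworkDispatcher) {
	e.dispatcher = dispatcher
	e.LinkEndpoint.Attach(dispatcher)
}

func (e *outboundSniffer) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
	// The NIC clones the packet, marks it PacketOutgoing and adds the link
	// header before handing it to ETH_P_ALL endpoints.
	e.dispatcher.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
	return e.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
}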
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
// TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
@@ -1358,16 +1374,19 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// TransportHeader is nil only when pkt is an ICMP packet or was reassembled
// from fragments.
if pkt.TransportHeader == nil {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their
- // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, so we parse it here. See icmp/protocol.go:protocol.Parse for a
// full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
+ // ICMP packets may be longer, but until icmp.Parse is implemented, here
+ // we parse it using the minimum size.
transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
if !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
pkt.TransportHeader = transHeader
+ pkt.Data.TrimFront(len(pkt.TransportHeader))
} else {
// This is either a bad packet or was re-assembled from fragments.
transProto.Parse(pkt)
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 31f865260..c477e31d8 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -84,6 +84,16 @@ func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
// An IPv6 NetworkEndpoint that throws away outgoing packets.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 1b5da6017..5d6865e35 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -14,6 +14,7 @@
package stack
import (
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -24,7 +25,7 @@ import (
// multiple endpoints. Clone() should be called in such cases so that
// modifications to the Data field do not affect other copies.
type PacketBuffer struct {
- _ noCopy
+ _ sync.NoCopy
// PacketBufferEntry is used to build an intrusive list of
// PacketBuffers.
@@ -78,6 +79,10 @@ type PacketBuffer struct {
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
@@ -102,14 +107,3 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NatDone: pk.NatDone,
}
}
-
-// noCopy may be embedded into structs which must not be copied
-// after the first use.
-//
-// See https://golang.org/issues/8005#issuecomment-190753527
-// for details.
-type noCopy struct{}
-
-// Lock is a no-op used by -copylocks checker from `go vet`.
-func (*noCopy) Lock() {}
-func (*noCopy) Unlock() {}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 5cbc946b6..9e1b2d25f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -51,8 +52,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -329,8 +333,7 @@ type NetworkProtocol interface {
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network and packet (if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -341,6 +344,16 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by the link layer when a packet is
+ // being sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -436,6 +449,15 @@ type LinkEndpoint interface {
// Wait will not block if the endpoint hasn't started any goroutines
// yet, even if it might later.
Wait()
+
+ // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
+ ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
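To illustrate the two new LinkEndpoint requirements, here is a sketch of how an Ethernet-backed endpoint might satisfy them. The ethernetEndpoint type is hypothetical, and header.ARPHardwareEther is assumed to be among the ARPHardwareType constants introduced alongside this change; the real implementations live in the link-endpoint packages updated elsewhere in this series.

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// ethernetEndpoint stands in for a real Ethernet link endpoint.
type ethernetEndpoint struct {
	stack.LinkEndpoint
}

// ARPHardwareType reports ARPHRD_ETHER for an Ethernet-like link.
func (*ethernetEndpoint) ARPHardwareType() header.ARPHardwareType {
	return header.ARPHardwareEther
}

// AddHeader prepends an Ethernet header describing this hop, so that packets
// intercepted before link-layer encapsulation still carry one.
func (*ethernetEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
	eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
	eth.Encode(&header.EthernetFields{
		SrcAddr: local,
		DstAddr: remote,
		Type:    protocol,
	})
}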
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a2190341c..a6faa22c2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -425,6 +425,7 @@ type Stack struct {
handleLocal bool
// tables are the iptables packet filtering and manipulation rules.
+ // TODO(gvisor.dev/issue/170): S/R this field.
tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
@@ -471,6 +472,14 @@ type Stack struct {
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
randomGenerator *mathrand.Rand
+
+ // sendBufferSize holds the min/default/max send buffer sizes for
+ // endpoints other than TCP.
+ sendBufferSize SendBufferSizeOption
+
+ // receiveBufferSize holds the min/default/max receive buffer sizes for
+ // endpoints other than TCP.
+ receiveBufferSize ReceiveBufferSizeOption
}
// UniqueID is an abstract generator of unique identifiers.
@@ -683,6 +692,16 @@ func New(opts Options) *Stack {
tempIIDSeed: opts.TempIIDSeed,
forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
+ receiveBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
}
// Add specified network protocols.
@@ -709,6 +728,11 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
// UniqueID returns a unique identifier.
func (s *Stack) UniqueID() uint64 {
return s.uniqueIDGenerator.UniqueID()
@@ -782,9 +806,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
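A brief usage sketch of the new accessor, written as if it lived in this file (the function name and duration are illustrative): callers that previously relied on Stack.NowNanoseconds now read the time through the clock, which also schedules work the same way the NDP code above does.

func exampleClockUsage(s *Stack) {
	nowNS := s.Clock().NowNanoseconds() // current time in nanoseconds
	_ = nowNS

	// Run a function after 500ms on the stack's clock; tests can supply a
	// fake clock to control when it fires.
	s.Clock().AfterFunc(500*time.Millisecond, func() {})
}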
// Stats returns a mutable copy of the current stats.
@@ -1033,14 +1058,14 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
// Remove routes in-place. n tracks the number of routes written.
n := 0
for i, r := range s.routeTable {
+ s.routeTable[i] = tcpip.Route{}
if r.NIC != id {
// Keep this route.
- if i > n {
- s.routeTable[n] = r
- }
+ s.routeTable[n] = r
n++
}
}
+
s.routeTable = s.routeTable[:n]
return nic.remove()
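The rewritten loop clears every slot before compacting so removed routes do not linger in the slice's backing array; as a standalone sketch, the in-place filter pattern it follows is:

// filterRoutes keeps only the routes that do not reference the removed NIC,
// zeroing each slot so dropped entries are not retained by the backing array.
func filterRoutes(routes []tcpip.Route, id tcpip.NICID) []tcpip.Route {
	n := 0
	for i, r := range routes {
		routes[i] = tcpip.Route{} // clear the slot unconditionally
		if r.NIC != id {
			routes[n] = r // keep this route
			n++
		}
	}
	return routes[:n]
}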
@@ -1076,6 +1101,11 @@ type NICInfo struct {
// Context is user-supplied data optionally supplied in CreateNICWithOptions.
// See type NICOptions for more details.
Context NICContext
+
+ // ARPHardwareType holds the ARP Hardware type of the NIC. This is the
+ // value sent in haType field of an ARP Request sent by this NIC and the
+ // value expected in the haType field of an ARP response.
+ ARPHardwareType header.ARPHardwareType
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -1107,6 +1137,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
Context: nic.context,
+ ARPHardwareType: nic.linkEP.ARPHardwareType(),
}
}
return nics
@@ -1408,6 +1439,12 @@ func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.N
return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
+// CheckRegisterTransportEndpoint checks if an endpoint can be registered with
+// the stack transport dispatcher.
+func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice)
+}
+
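A hedged sketch of the intended use, written as if it lived in this file (the helper name is illustrative; the concrete callers land in the transport endpoints, not here): validate that a binding would succeed without actually reserving it.

func canRegister(s *Stack, netProtos []tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) bool {
	// The NIC ID is accepted for symmetry with RegisterTransportEndpoint but
	// is not consulted by the check itself.
	return s.CheckRegisterTransportEndpoint(0 /* nicID */, netProtos, trans, id, flags, bindToDevice) == nil
}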
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
new file mode 100644
index 000000000..0b093e6c5
--- /dev/null
+++ b/pkg/tcpip/stack/stack_options.go
@@ -0,0 +1,106 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+const (
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4 KiB
+
+ // DefaultBufferSize is the default size of the send/recv buffer for a
+ // transport endpoint.
+ DefaultBufferSize = 212 << 10 // 212 KiB
+
+ // DefaultMaxBufferSize is the default maximum permitted size of a
+ // send/receive buffer.
+ DefaultMaxBufferSize = 4 << 20 // 4 MiB
+)
+
+// SendBufferSizeOption is used by stack.(*Stack).Option/SetOption to
+// get/set the default, min and max send buffer sizes.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption is used by stack.(*Stack).Option/SetOption to
+// get/set the default, min and max receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// SetOption allows setting stack-wide options.
+func (s *Stack) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SendBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below the minimum
+ // required for the stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.sendBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below the minimum
+ // required for the stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.receiveBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option allows retrieving stack-wide options.
+func (s *Stack) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SendBufferSizeOption:
+ s.mu.RLock()
+ *v = s.sendBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ s.mu.RLock()
+ *v = s.receiveBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
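A short usage sketch, written as if it lived in this file (the intermediate values are arbitrary): apply new send-buffer limits and read them back, mirroring what the tests added below exercise.

func exampleSendBufferSizeOption(s *Stack) *tcpip.Error {
	opt := SendBufferSizeOption{
		Min:     MinBufferSize,
		Default: 64 << 10, // 64 KiB default for new non-TCP endpoints
		Max:     DefaultMaxBufferSize,
	}
	if err := s.SetOption(opt); err != nil {
		return err
	}

	var got SendBufferSizeOption
	if err := s.Option(&got); err != nil {
		return err
	}
	// got now holds the limits that will be applied to new endpoints.
	return nil
}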
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index ffef9bc2c..7657a4101 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -3305,7 +3305,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
@@ -3338,3 +3338,83 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
}
}
+
+func TestStackReceiveBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ rs stack.ReceiveBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid configurations.
+ {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.rs); err != tc.err {
+ t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
+ }
+ var rs stack.ReceiveBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&rs); err != nil {
+ t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
+ }
+ if got, want := rs, tc.rs; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestStackSendBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ ss stack.SendBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid configurations.
+ {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.ss); err != tc.err {
+ t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err)
+ }
+ var ss stack.SendBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&ss); err != nil {
+ t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
+ }
+ if got, want := ss, tc.ss; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
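The setter exercised by these tests lives in the new stack_options.go and is not shown in this hunk, but the invalid cases above pin down the rule it must enforce. A sketch of that rule only, using a hypothetical helper name rather than the real implementation:

	// validateBufferSizeOption captures the constraints implied by the test
	// cases: Min must be positive and Default must lie within [Min, Max].
	func validateBufferSizeOption(min, def, max int) *tcpip.Error {
		if min < 1 || max < min || def < min || def > max {
			return tcpip.ErrInvalidOptionValue
		}
		return nil
	}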
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 118b449d5..b902c6ca9 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -221,6 +221,18 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
return multiPortEp.singleRegisterEndpoint(t, flags)
}
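+// checkEndpoint checks if an endpoint can be registered with this
+// endpointsByNIC, without registering it.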
+func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
+
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
+ if !ok {
+ return nil
+ }
+
+ return multiPortEp.singleCheckEndpoint(flags)
+}
+
// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
epsByNIC.mu.Lock()
@@ -289,6 +301,17 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
return nil
}
+// checkEndpoint checks if an endpoint can be registered with the dispatcher.
+func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ for _, n := range netProtos {
+ if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
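The checkEndpoint path added above deliberately mirrors registerEndpoint but takes no endpoint and acquires only read locks, so a caller can probe for a port conflict before committing to a bind. A hedged sketch of such a caller; checkBind is hypothetical, lives inside the stack package, and is not part of this diff:

	// checkBind reports whether binding id with the given flags would succeed,
	// without actually registering an endpoint.
	func checkBind(d *transportDemuxer, netProtos []tcpip.NetworkProtocolNumber, proto tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, dev tcpip.NICID) bool {
		return d.checkEndpoint(netProtos, proto, id, flags, dev) == nil
	}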
@@ -380,7 +403,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p
ep.mu.Lock()
defer ep.mu.Unlock()
- bits := flags.Bits()
+ bits := flags.Bits() & ports.MultiBindFlagMask
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
@@ -395,6 +418,22 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p
return nil
}
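+// singleCheckEndpoint returns tcpip.ErrPortInUse if an endpoint bound with
+// the given flags could not coexist with the endpoints already present; it
+// does not modify any state.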
+func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
+ if len(ep.endpoints) != 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ return nil
+}
+
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
ep.mu.Lock()
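singleCheckEndpoint mirrors the registration-time conflict rule without mutating the flag counters: a newcomer's multi-bind flags must intersect the flags shared by every endpoint already bound to the port. A self-contained illustration of that rule, using ad-hoc flag bits as stand-ins for the real ports package types:

	package main

	import "fmt"

	// Hypothetical stand-ins for the multi-bind flag bits tracked per endpoint.
	const (
		reuseAddr uint8 = 1 << iota // SO_REUSEADDR-style sharing
		reusePort                   // SO_REUSEPORT-style load balancing
	)

	// mayBind mimics the check: the new binding is allowed only if its flags
	// overlap the intersection of the flags of all existing bindings.
	func mayBind(intersectionOfExisting, newcomer uint8) bool {
		return intersectionOfExisting&newcomer != 0
	}

	func main() {
		fmt.Println(mayBind(reusePort, reusePort))           // true: everyone load-balances
		fmt.Println(mayBind(reusePort, reuseAddr))           // false: would yield ErrPortInUse
		fmt.Println(mayBind(reuseAddr|reusePort, reusePort)) // true: newcomer shares a flag
	}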
@@ -406,7 +445,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports
ep.endpoints[len(ep.endpoints)-1] = nil
ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
- ep.flags.DropRef(flags.Bits())
+ ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
break
}
}
@@ -439,6 +478,28 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
}
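+// singleCheckEndpoint checks if an endpoint can be registered for a single
+// network protocol, without registering it.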
+func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ return nil
+ }
+
+ return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
+}
+
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {