diff options
Diffstat (limited to 'pkg/tcpip')
23 files changed, 816 insertions, 221 deletions
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 44e25d475..69de6eb3e 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -69,13 +69,12 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { // NonBlockingWrite3 writes up to three byte slices to a file descriptor in a // single syscall. It fails if partial data is written. func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { - // If the is no second buffer, issue a regular write. - if len(b2) == 0 { + // If there is no second and third buffer, issue a regular write. + if len(b2) == 0 && len(b3) == 0 { return NonBlockingWrite(fd, b1) } - // We have two buffers. Build the iovec that represents them and issue - // a writev syscall. + // Build the iovec that represents them and issue a writev syscall. iovec := [3]syscall.Iovec{ { Base: &b1[0], diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 7e9f16c90..b1776e5ee 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -225,12 +225,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payloadSize) - id := uint32(0) - if length > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) - } + // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic + // datagrams. Since the DF bit is never being set here, all datagrams + // are non-atomic and need an ID. + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, @@ -376,13 +374,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // Set the packet ID when zero. if ip.ID() == 0 { - id := uint32(0) - if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) + // RFC 6864 section 4.3 mandates uniqueness of ID values for + // non-atomic datagrams, so assign an ID to all such datagrams + // according to the definition given in RFC 6864 section 4. + if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } - ip.SetID(uint16(id)) } // Always set the checksum. diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 794ddb5c8..800bf3f08 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -27,6 +27,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ @@ -35,6 +47,7 @@ go_library( "forwarder.go", "icmp_rate_limit.go", "iptables.go", + "iptables_state.go", "iptables_targets.go", "iptables_types.go", "linkaddrcache.go", @@ -50,6 +63,7 @@ go_library( "stack_global_state.go", "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], visibility = ["//visibility:public"], deps = [ @@ -79,6 +93,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], + shard_count = 20, deps = [ ":stack", "//pkg/rand", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index af9c325ca..d39baf620 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -15,9 +15,12 @@ package stack import ( + "encoding/binary" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" ) @@ -30,6 +33,10 @@ import ( // // Currently, only TCP tracking is supported. +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 + // Direction of the tuple. type direction int @@ -48,7 +55,12 @@ const ( // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. +// +// +stateify savable type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry + tupleID // conn is the connection tracking entry this tuple belongs to. @@ -61,6 +73,8 @@ type tuple struct { // tupleID uniquely identifies a connection in one direction. It currently // contains enough information to distinguish between any TCP or UDP // connection, and will need to be extended to support other protocols. +// +// +stateify savable type tupleID struct { srcAddr tcpip.Address srcPort uint16 @@ -83,6 +97,8 @@ func (ti tupleID) reply() tupleID { } // conn is a tracked connection. +// +// +stateify savable type conn struct { // original is the tuple in original direction. It is immutable. original tuple @@ -98,22 +114,67 @@ type conn struct { tcbHook Hook // mu protects tcb. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB + + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` +} + +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. + return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout } // ConnTrack tracks all connections created for NAT rules. Most users are // expected to only call handlePacket and createConnFor. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable type ConnTrack struct { - // mu protects conns. - mu sync.RWMutex + // seed is a one-time random value initialized at stack startup + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 - // conns maintains a map of tuples needed for connection tracking for - // iptables NAT rules. It is protected by mu. - conns map[tupleID]tuple + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` + + // buckets is protected by mu. + buckets []bucket +} + +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList } // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid @@ -143,8 +204,9 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // newConn creates new connection. func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { conn := conn{ - manip: manip, - tcbHook: hook, + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), } conn.original = tuple{conn: &conn, tupleID: orig} conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} @@ -162,14 +224,28 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { return nil, dirOriginal } - ct.mu.Lock() - defer ct.mu.Unlock() - - tuple, ok := ct.conns[tid] - if !ok { - return nil, dirOriginal + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } } - return tuple.conn, tuple.direction + + return nil, dirOriginal } // createConnFor creates a new conn for pkt. @@ -197,13 +273,31 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg } conn := newConn(tid, replyTID, manip, hook) - // Add the changed tuple to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked - // list. - ct.mu.Lock() - defer ct.mu.Unlock() - ct.conns[tid] = conn.original - ct.conns[replyTID] = conn.reply + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(tid) + replyBucket := ct.bucket(replyTID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. + ct.buckets[tupleBucket].mu.Lock() + } + + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } return conn } @@ -297,35 +391,134 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() - var st tcpconntrack.Result - tcpHeader := header.TCP(pkt.TransportHeader) - if conn.tcb.IsEmpty() { + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() { conn.tcb.Init(tcpHeader) conn.tcbHook = hook + } else if hook == conn.tcbHook { + conn.tcb.UpdateStateOutbound(tcpHeader) } else { - switch hook { - case conn.tcbHook: - st = conn.tcb.UpdateStateOutbound(tcpHeader) - default: - st = conn.tcb.UpdateStateInbound(tcpHeader) - } + conn.tcb.UpdateStateInbound(tcpHeader) } +} + +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} - // Delete conn if tcp connection is closed. - if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConn(conn) +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } + } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval } + // We've hit the maximum interval. + return idx, maxInterval } -// deleteConn deletes the connection. -func (ct *ConnTrack) deleteConn(conn *conn) { - if conn == nil { - return +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: ct.mu is locked for reading and bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false } - ct.mu.Lock() - defer ct.mu.Unlock() + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } + + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) + } + ct.buckets[bucket].tuples.Remove(tuple) + + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() + } - delete(ct.conns, conn.original.tupleID) - delete(ct.conns, conn.reply.tupleID) + return true } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 974d77c36..f846ea2e5 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -41,6 +42,9 @@ const ( // underflow. const HookUnset = -1 +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() *IPTables { @@ -112,8 +116,9 @@ func DefaultTables() *IPTables { Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, connections: ConnTrack{ - conns: make(map[tupleID]tuple), + seed: generateRandUint32(), }, + reaperDone: make(chan struct{}, 1), } } @@ -169,6 +174,12 @@ func (it *IPTables) GetTable(name string) (Table, bool) { func (it *IPTables) ReplaceTable(name string, table Table) { it.mu.Lock() defer it.mu.Unlock() + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } it.modified = true it.tables[name] = table } @@ -249,6 +260,35 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr return true } +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index c528ec381..eb70e3104 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -78,6 +78,8 @@ const ( ) // IPTables holds all the tables for a netstack. +// +// +stateify savable type IPTables struct { // mu protects tables, priorities, and modified. mu sync.RWMutex @@ -97,10 +99,15 @@ type IPTables struct { modified bool connections ConnTrack + + // reaperDone can be signalled to stop the reaper goroutine. + reaperDone chan struct{} } // A Table defines a set of chains and hooks into the network stack. It is // really just a list of rules. +// +// +stateify savable type Table struct { // Rules holds the rules that make up the table. Rules []Rule @@ -130,6 +137,8 @@ func (table *Table) ValidHooks() uint32 { // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. +// +// +stateify savable type Rule struct { // Filter holds basic IP filtering fields common to every rule. Filter IPHeaderFilter @@ -142,6 +151,8 @@ type Rule struct { } // IPHeaderFilter holds basic IP filtering data common to every rule. +// +// +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index afb7dfeaf..7b80534e6 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1358,16 +1358,19 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // TransportHeader is nil only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader == nil { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + // ICMP packets may be longer, but until icmp.Parse is implemented, here + // we parse it using the minimum size. transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } pkt.TransportHeader = transHeader + pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1b5da6017..e3556d5d2 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,6 +14,7 @@ package stack import ( + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -24,7 +25,7 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { - _ noCopy + _ sync.NoCopy // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. @@ -102,14 +103,3 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NatDone: pk.NatDone, } } - -// noCopy may be embedded into structs which must not be copied -// after the first use. -// -// See https://golang.org/issues/8005#issuecomment-190753527 -// for details. -type noCopy struct{} - -// Lock is a no-op used by -copylocks checker from `go vet`. -func (*noCopy) Lock() {} -func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cdcfb8321..0aa815447 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -425,6 +425,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. + // TODO(gvisor.dev/issue/170): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2be1c107a..71bcee785 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -648,6 +648,11 @@ const ( // whether an IPv6 socket is to be restricted to sending and receiving // IPv6 packets only. V6OnlyOption + + // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw + // endpoint that all packets being written have an IP header and the + // endpoint should not attach an IP header. + IPHdrIncludedOption ) // SockOptInt represents socket options which values have the int type. @@ -673,6 +678,13 @@ const ( // TCP_MAXSEG option. MaxSegOption + // MTUDiscoverOption is used to set/get the path MTU discovery setting. + // + // NOTE: Setting this option to any other value than PMTUDiscoveryDont + // is not supported and will fail as such, and getting this option will + // always return PMTUDiscoveryDont. + MTUDiscoverOption + // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control // the default TTL value for multicast messages. The default is 1. MulticastTTLOption @@ -714,6 +726,24 @@ const ( TCPWindowClampOption ) +const ( + // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use + // per-route settings. + PMTUDiscoveryWant int = iota + + // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable + // path MTU discovery. + PMTUDiscoveryDont + + // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do + // path MTU discovery. + PMTUDiscoveryDo + + // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF + // but ignore path MTU. + PMTUDiscoveryProbe +) + // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} @@ -752,7 +782,7 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// buffer moderation. +// ModerateReceiveBufferOption is used by buffer moderation. type ModerateReceiveBufferOption bool // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the @@ -825,7 +855,10 @@ type OutOfBandInlineOption int // a default TTL. type DefaultTTLOption uint8 -// +// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached +// classic BPF filter on a given endpoint. +type SocketDetachFilterOption int + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable @@ -1214,6 +1247,9 @@ type UDPStats struct { // ChecksumErrors is the number of datagrams dropped due to bad checksums. ChecksumErrors *StatCounter + + // InvalidSourceAddress is the number of invalid sourced datagrams dropped. + InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go index 59f3b391f..5554c573f 100644 --- a/pkg/tcpip/timer.go +++ b/pkg/tcpip/timer.go @@ -15,8 +15,9 @@ package tcpip import ( - "sync" "time" + + "gvisor.dev/gvisor/pkg/sync" ) // cancellableTimerInstance is a specific instance of CancellableTimer. @@ -92,6 +93,8 @@ func (t *cancellableTimerInstance) stop() { // Note, it is not safe to copy a CancellableTimer as its timer instance creates // a closure over the address of the CancellableTimer. type CancellableTimer struct { + _ sync.NoCopy + // The active instance of a cancellable timer. instance cancellableTimerInstance @@ -157,22 +160,6 @@ func (t *CancellableTimer) Reset(d time.Duration) { } } -// Lock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Lock() {} - -// Unlock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Unlock() {} - // NewCancellableTimer returns an unscheduled CancellableTimer with the given // locker and fn. // diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8ce294002..678f4e016 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + } return nil } @@ -744,15 +748,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.TransportHeader) + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.TransportHeader) + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return @@ -786,7 +790,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk }, } - packet.data = pkt.Data + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index baf08eda6..57b7f5c19 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,6 +25,8 @@ package packet import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -71,11 +73,12 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool } // NewEndpoint returns a new packet endpoint. @@ -92,6 +95,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb sndBufSize: 32 * 1024, } + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.rcvBufSizeMax = rs.Default + } + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { return nil, err } @@ -264,7 +278,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. @@ -274,7 +294,46 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := ep.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + ep.mu.Lock() + ep.sndBufSizeMax = v + ep.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := ep.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + ep.rcvMu.Lock() + ep.rcvBufSizeMax = v + ep.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -289,7 +348,32 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + ep.rcvMu.Lock() + if !ep.rcvList.Empty() { + p := ep.rcvList.Front() + v = p.data.Size() + } + ep.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSizeMax + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // HandlePacket implements stack.PacketEndpoint.HandlePacket. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 766c7648e..c2e9fd29f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -63,6 +63,7 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool + hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -108,6 +109,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, + hdrIncluded: !associated, } // Override with stack defaults. @@ -182,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -263,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -353,7 +351,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if !e.associated { + if e.hdrIncluded { if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { @@ -508,11 +506,24 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } return tcpip.ErrUnknownProtocolOption } @@ -577,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.KeepaliveEnabledOption: return false, nil + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -616,8 +633,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() - // Drop the packet if our buffer is currently full. - if e.rcvClosed { + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 6baeda8e4..18ff89ffc 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -86,6 +86,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], + shard_count = 10, deps = [ ":tcp", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 047704c80..98aecab9e 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -15,6 +15,8 @@ package tcp import ( + "encoding/binary" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -66,89 +68,68 @@ func (q *epQueue) empty() bool { // processor is responsible for processing packets queued to a tcp endpoint. type processor struct { epQ epQueue + sleeper sleep.Sleeper newEndpointWaker sleep.Waker closeWaker sleep.Waker - id int - wg sync.WaitGroup -} - -func newProcessor(id int) *processor { - p := &processor{ - id: id, - } - p.wg.Add(1) - go p.handleSegments() - return p } func (p *processor) close() { p.closeWaker.Assert() } -func (p *processor) wait() { - p.wg.Wait() -} - func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) p.newEndpointWaker.Assert() } -func (p *processor) handleSegments() { - const newEndpointWaker = 1 - const closeWaker = 2 - s := sleep.Sleeper{} - s.AddWaker(&p.newEndpointWaker, newEndpointWaker) - s.AddWaker(&p.closeWaker, closeWaker) - defer s.Done() +const ( + newEndpointWaker = 1 + closeWaker = 2 +) + +func (p *processor) start(wg *sync.WaitGroup) { + defer wg.Done() + defer p.sleeper.Done() + for { - id, ok := s.Fetch(true) - if ok && id == closeWaker { - p.wg.Done() - return + if id, _ := p.sleeper.Fetch(true); id == closeWaker { + break } - for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + for { + ep := p.epQ.dequeue() + if ep == nil { + break + } if ep.segmentQueue.empty() { continue } - // If socket has transitioned out of connected state - // then just let the worker handle the packet. + // If socket has transitioned out of connected state then just let the + // worker handle the packet. // - // NOTE: We read this outside of e.mu lock which means - // that by the time we get to handleSegments the - // endpoint may not be in ESTABLISHED. But this should - // be fine as all normal shutdown states are handled by - // handleSegments and if the endpoint moves to a - // CLOSED/ERROR state then handleSegments is a noop. - if ep.EndpointState() != StateEstablished { - ep.newSegmentWaker.Assert() - continue - } - - if !ep.mu.TryLock() { - ep.newSegmentWaker.Assert() - continue - } - // If the endpoint is in a connected state then we do - // direct delivery to ensure low latency and avoid - // scheduler interactions. - if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { - // Send any active resets if required. - if err != nil { + // NOTE: We read this outside of e.mu lock which means that by the time + // we get to handleSegments the endpoint may not be in ESTABLISHED. But + // this should be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a CLOSED/ERROR state + // then handleSegments is a noop. + if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { + // If the endpoint is in a connected state then we do direct delivery + // to ensure low latency and avoid scheduler interactions. + switch err := ep.handleSegments(true /* fastPath */); { + case err != nil: + // Send any active resets if required. ep.resetConnectionLocked(err) + fallthrough + case ep.EndpointState() == StateClose: + ep.notifyProtocolGoroutine(notifyTickleWorker) + case !ep.segmentQueue.empty(): + p.epQ.enqueue(ep) } - ep.notifyProtocolGoroutine(notifyTickleWorker) ep.mu.Unlock() - continue - } - - if !ep.segmentQueue.empty() { - p.epQ.enqueue(ep) + } else { + ep.newSegmentWaker.Assert() } - - ep.mu.Unlock() } } } @@ -159,31 +140,36 @@ func (p *processor) handleSegments() { // hash of the endpoint id to ensure that delivery for the same endpoint happens // in-order. type dispatcher struct { - processors []*processor + processors []processor seed uint32 -} - -func newDispatcher(nProcessors int) *dispatcher { - processors := []*processor{} - for i := 0; i < nProcessors; i++ { - processors = append(processors, newProcessor(i)) - } - return &dispatcher{ - processors: processors, - seed: generateRandUint32(), + wg sync.WaitGroup +} + +func (d *dispatcher) init(nProcessors int) { + d.close() + d.wait() + d.processors = make([]processor, nProcessors) + d.seed = generateRandUint32() + for i := range d.processors { + p := &d.processors[i] + p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) + p.sleeper.AddWaker(&p.closeWaker, closeWaker) + d.wg.Add(1) + // NB: sleeper-waker registration must happen synchronously to avoid races + // with `close`. It's possible to pull all this logic into `start`, but + // that results in a heap-allocated function literal. + go p.start(&d.wg) } } func (d *dispatcher) close() { - for _, p := range d.processors { - p.close() + for i := range d.processors { + d.processors[i].close() } } func (d *dispatcher) wait() { - for _, p := range d.processors { - p.wait() - } + d.wg.Wait() } func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -231,20 +217,18 @@ func generateRandUint32() uint32 { if _, err := rand.Read(b); err != nil { panic(err) } - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return binary.LittleEndian.Uint32(b) } func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { - payload := []byte{ - byte(id.LocalPort), - byte(id.LocalPort >> 8), - byte(id.RemotePort), - byte(id.RemotePort >> 8)} + var payload [4]byte + binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) + binary.LittleEndian.PutUint16(payload[2:], id.RemotePort) h := jenkins.Sum32(d.seed) - h.Write(payload) + h.Write(payload[:]) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) - return d.processors[h.Sum32()%uint32(len(d.processors))] + return &d.processors[h.Sum32()%uint32(len(d.processors))] } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index bd3ec5a8d..83dc10ed0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1589,6 +1589,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.MTUDiscoverOption: + // Return not supported if attempting to set this option to + // anything other than path MTU discovery disabled. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1785,6 +1792,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.deferAccept = time.Duration(v) e.UnlockUser() + case tcpip.SocketDetachFilterOption: + return nil + default: return nil } @@ -1896,6 +1906,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := header.TCPDefaultMSS return v, nil + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index f2ae6ce50..b34e47bbd 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -174,7 +174,7 @@ type protocol struct { maxRetries uint32 synRcvdCount synRcvdCounter synRetries uint8 - dispatcher *dispatcher + dispatcher dispatcher } // Number returns the tcp protocol number. @@ -515,7 +515,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{ + p := protocol{ sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, @@ -531,10 +531,11 @@ func NewProtocol() stack.TransportProtocol { tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, - dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, } + p.dispatcher.init(runtime.GOMAXPROCS(0)) + return &p } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 169adb16b..e67ec42b1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3095,6 +3095,63 @@ func TestMaxRTO(t *testing.T) { } } +// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is +// unique on retransmits. +func TestRetransmitIPv4IDUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + size int + }{ + {"1Byte", 1}, + {"512Bytes", 512}, + } { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + // Disabling PMTU discovery causes all packets sent from this socket to + // have DF=0. This needs to be done because the IPv4 ID uniqueness + // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 + // Section 4, and datagrams with DF=0 are non-atomic. + if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { + t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) + } + + if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}} + // Expect two retransmitted packets, and that all packets received have + // unique IPv4 ID values. + for i := 0; i <= 2; i++ { + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + id := header.IPv4(pkt).ID() + if _, exists := idSet[id]; exists { + t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + } + idSet[id] = struct{}{} + } + }) + } +} + func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index cae29fbff..a14643ae8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -612,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -809,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() + + case tcpip.SocketDetachFilterOption: + return nil } return nil } @@ -906,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.MulticastTTLOption: e.mu.Lock() v := int(e.multicastTTL) @@ -1366,6 +1380,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + // Verify checksum unless RX checksum offload is enabled. // On IPv4, UDP checksum is optional, and a zero value means // the transmitter omitted the checksum generation (RFC768). @@ -1384,10 +1407,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } } - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index db59eb5a0..90781cf49 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -872,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + +// TestReadFromMulticaststats checks that a discarded packet +// that that was sent with multicast SOURCE address increments +// the correct counters and that a regular packet does not. +func TestReadFromMulticastStats(t *testing.T) { + t.Helper() + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + c.injectPacket(flow, payload) + + var want uint64 = 0 + if flow.isReverseMulticast() { + want = 1 + } + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { + t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { + t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -1721,9 +1793,11 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Invalidate the packet length field in the UDP header by adding one. + + // Invalidate the UDP header length field. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetLength(u.Length() + 1) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -1803,9 +1877,16 @@ func TestIncrementChecksumErrorsV4(t *testing.T) { payload := newPayload() h := unicastV4.header4Tuple(incoming) buf := c.buildV4Packet(payload, &h) - // Invalidate the checksum field in the UDP header by adding one. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.SetChecksum(u.Checksum() + 1) + + // Invalidate the UDP header checksum field, taking care to avoid + // overflow to zero, which would disable checksum validation. + for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { + u.SetChecksum(u.Checksum() + 1) + if u.Checksum() != 0 { + break + } + } + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -1834,9 +1915,11 @@ func TestIncrementChecksumErrorsV6(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Invalidate the checksum field in the UDP header by adding one. + + // Invalidate the UDP header checksum field. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(u.Checksum() + 1) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) |