diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/stack/conntrack_test.go | 132 |
3 files changed, 163 insertions, 22 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 6999add78..ead36880f 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -132,6 +132,7 @@ go_test( name = "stack_test", size = "small", srcs = [ + "conntrack_test.go", "forwarding_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 3583d93c6..89f8ef09f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -37,6 +37,11 @@ import ( // Our hash table has 16K buckets. const numBuckets = 1 << 14 +const ( + establishedTimeout time.Duration = 5 * 24 * time.Hour + unestablishedTimeout time.Duration = 120 * time.Second +) + // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. // @@ -128,8 +133,6 @@ type conn struct { // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { - const establishedTimeout = 5 * 24 * time.Hour - const defaultTimeout = 120 * time.Second cn.mu.RLock() defer cn.mu.RUnlock() if cn.tcb.State() == tcpconntrack.ResultAlive { @@ -139,7 +142,7 @@ func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { } // Use the same default as Linux, which lets connections in most states // other than established remain for <= 120 seconds. - return now.Sub(cn.lastUsed) > defaultTimeout + return now.Sub(cn.lastUsed) > unestablishedTimeout } // update the connection tracking state. @@ -403,7 +406,7 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) // has had NAT performed on it. // // Returns true if the packet can skip the NAT table. -func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { transportHeader, ok := getTransportHeader(pkt) if !ok { return false @@ -432,7 +435,7 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { case Postrouting: if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true - } else if r.RequiresTXTransportChecksum() { + } else if rt.RequiresTXTransportChecksum() { fullChecksum = true updatePseudoHeader = true } @@ -453,6 +456,11 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { cn.mu.Lock() defer cn.mu.Unlock() + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = cn.ct.clock.NowMonotonic() + // Update connection state. + cn.updateLocked(pkt, reply) + var tuple *tuple if reply { if dnat { @@ -476,11 +484,6 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { tuple = &cn.reply } - // Mark the connection as having been used recently so it isn't reaped. - cn.lastUsed = cn.ct.clock.NowMonotonic() - // Update connection state. - cn.updateLocked(pkt, reply) - return tuple.id(), true }() if !performManip { @@ -598,24 +601,29 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now t return false } - // To maintain lock order, we can only reap these tuples if the reply - // appears later in the table. + // To maintain lock order, we can only reap both tuples if the reply appears + // later in the table. replyBktID := ct.bucket(tuple.id().reply()) - if bktID > replyBktID { + tuple.conn.mu.RLock() + replyTupleInserted := tuple.conn.finalized + tuple.conn.mu.RUnlock() + if bktID > replyBktID && replyTupleInserted { return true } - // Don't re-lock if both tuples are in the same bucket. - if bktID != replyBktID { - replyBkt := &ct.buckets[replyBktID] - replyBkt.mu.Lock() - removeConnFromBucket(replyBkt, tuple) - replyBkt.mu.Unlock() - } else { - removeConnFromBucket(bkt, tuple) + // Reap the reply. + if replyTupleInserted { + // Don't re-lock if both tuples are in the same bucket. + if bktID != replyBktID { + replyBkt := &ct.buckets[replyBktID] + replyBkt.mu.Lock() + removeConnFromBucket(replyBkt, tuple) + replyBkt.mu.Unlock() + } else { + removeConnFromBucket(bkt, tuple) + } } - // We have the buckets locked and can remove both tuples. bkt.tuples.Remove(tuple) return true } diff --git a/pkg/tcpip/stack/conntrack_test.go b/pkg/tcpip/stack/conntrack_test.go new file mode 100644 index 000000000..fb0645ed1 --- /dev/null +++ b/pkg/tcpip/stack/conntrack_test.go @@ -0,0 +1,132 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" +) + +func TestReap(t *testing.T) { + // Initialize conntrack. + clock := faketime.NewManualClock() + ct := ConnTrack{ + clock: clock, + } + ct.init() + ct.checkNumTuples(t, 0) + + // Simulate sending a SYN. This will get the connection into conntrack, but + // the connection won't be considered established. Thus the timeout for + // reaping is unestablishedTimeout. + pkt1 := genTCPPacket() + pkt1.tuple = ct.getConnOrMaybeInsertNoop(pkt1) + // We set rt.routeInfo.Loop to avoid a panic when handlePacket calls + // rt.RequiresTXTransportChecksum. + var rt Route + rt.routeInfo.Loop = PacketLoop + if pkt1.tuple.conn.handlePacket(pkt1, Output, &rt) { + t.Fatal("handlePacket() shouldn't perform any NAT") + } + ct.checkNumTuples(t, 1) + + // Travel a little into the future and send the same SYN. This should update + // lastUsed, but per #6748 didn't. + clock.Advance(unestablishedTimeout / 2) + pkt2 := genTCPPacket() + pkt2.tuple = ct.getConnOrMaybeInsertNoop(pkt2) + if pkt2.tuple.conn.handlePacket(pkt2, Output, &rt) { + t.Fatal("handlePacket() shouldn't perform any NAT") + } + ct.checkNumTuples(t, 1) + + // Travel farther into the future - enough that failing to update lastUsed + // would cause a reaping - and reap the whole table. Make sure the connection + // hasn't been reaped. + clock.Advance(unestablishedTimeout * 3 / 4) + ct.reapEverything() + ct.checkNumTuples(t, 1) + + // Travel past unestablishedTimeout to confirm the tuple is gone. + clock.Advance(unestablishedTimeout / 2) + ct.reapEverything() + ct.checkNumTuples(t, 0) +} + +// genTCPPacket returns an initialized IPv4 TCP packet. +func genTCPPacket() *PacketBuffer { + const packetLen = header.IPv4MinimumSize + header.TCPMinimumSize + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: packetLen, + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.TransportProtocolNumber = header.TCPProtocolNumber + tcpHdr := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize)) + tcpHdr.Encode(&header.TCPFields{ + SrcPort: 5555, + DstPort: 6666, + SeqNum: 7777, + AckNum: 8888, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + Checksum: 0, // Conntrack doesn't verify the checksum. + }) + ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: packetLen, + Protocol: uint8(header.TCPProtocolNumber), + SrcAddr: testutil.MustParse4("1.0.0.1"), + DstAddr: testutil.MustParse4("1.0.0.2"), + Checksum: 0, // Conntrack doesn't verify the checksum. + }) + + return pkt +} + +// checkNumTuples checks that there are exactly want tuples tracked by +// conntrack. +func (ct *ConnTrack) checkNumTuples(t *testing.T, want int) { + t.Helper() + ct.mu.RLock() + defer ct.mu.RUnlock() + + var total int + for idx := range ct.buckets { + ct.buckets[idx].mu.RLock() + total += ct.buckets[idx].tuples.Len() + ct.buckets[idx].mu.RUnlock() + } + + if total != want { + t.Fatalf("checkNumTuples: got %d, wanted %d", total, want) + } +} + +func (ct *ConnTrack) reapEverything() { + var bucket int + for { + newBucket, _ := ct.reapUnused(bucket, 0 /* ignored */) + // We started reaping at bucket 0. If the next bucket isn't after our + // current bucket, we've gone through them all. + if newBucket <= bucket { + break + } + bucket = newBucket + } +} |