summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/conntrack.go52
-rw-r--r--pkg/tcpip/stack/conntrack_test.go132
3 files changed, 163 insertions, 22 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 6999add78..ead36880f 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -132,6 +132,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
+ "conntrack_test.go",
"forwarding_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 3583d93c6..89f8ef09f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -37,6 +37,11 @@ import (
// Our hash table has 16K buckets.
const numBuckets = 1 << 14
+const (
+ establishedTimeout time.Duration = 5 * 24 * time.Hour
+ unestablishedTimeout time.Duration = 120 * time.Second
+)
+
// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
//
@@ -128,8 +133,6 @@ type conn struct {
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
- const establishedTimeout = 5 * 24 * time.Hour
- const defaultTimeout = 120 * time.Second
cn.mu.RLock()
defer cn.mu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
@@ -139,7 +142,7 @@ func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
}
// Use the same default as Linux, which lets connections in most states
// other than established remain for <= 120 seconds.
- return now.Sub(cn.lastUsed) > defaultTimeout
+ return now.Sub(cn.lastUsed) > unestablishedTimeout
}
// update the connection tracking state.
@@ -403,7 +406,7 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool)
// has had NAT performed on it.
//
// Returns true if the packet can skip the NAT table.
-func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
transportHeader, ok := getTransportHeader(pkt)
if !ok {
return false
@@ -432,7 +435,7 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
case Postrouting:
if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
- } else if r.RequiresTXTransportChecksum() {
+ } else if rt.RequiresTXTransportChecksum() {
fullChecksum = true
updatePseudoHeader = true
}
@@ -453,6 +456,11 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
cn.mu.Lock()
defer cn.mu.Unlock()
+ // Mark the connection as having been used recently so it isn't reaped.
+ cn.lastUsed = cn.ct.clock.NowMonotonic()
+ // Update connection state.
+ cn.updateLocked(pkt, reply)
+
var tuple *tuple
if reply {
if dnat {
@@ -476,11 +484,6 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
tuple = &cn.reply
}
- // Mark the connection as having been used recently so it isn't reaped.
- cn.lastUsed = cn.ct.clock.NowMonotonic()
- // Update connection state.
- cn.updateLocked(pkt, reply)
-
return tuple.id(), true
}()
if !performManip {
@@ -598,24 +601,29 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now t
return false
}
- // To maintain lock order, we can only reap these tuples if the reply
- // appears later in the table.
+ // To maintain lock order, we can only reap both tuples if the reply appears
+ // later in the table.
replyBktID := ct.bucket(tuple.id().reply())
- if bktID > replyBktID {
+ tuple.conn.mu.RLock()
+ replyTupleInserted := tuple.conn.finalized
+ tuple.conn.mu.RUnlock()
+ if bktID > replyBktID && replyTupleInserted {
return true
}
- // Don't re-lock if both tuples are in the same bucket.
- if bktID != replyBktID {
- replyBkt := &ct.buckets[replyBktID]
- replyBkt.mu.Lock()
- removeConnFromBucket(replyBkt, tuple)
- replyBkt.mu.Unlock()
- } else {
- removeConnFromBucket(bkt, tuple)
+ // Reap the reply.
+ if replyTupleInserted {
+ // Don't re-lock if both tuples are in the same bucket.
+ if bktID != replyBktID {
+ replyBkt := &ct.buckets[replyBktID]
+ replyBkt.mu.Lock()
+ removeConnFromBucket(replyBkt, tuple)
+ replyBkt.mu.Unlock()
+ } else {
+ removeConnFromBucket(bkt, tuple)
+ }
}
- // We have the buckets locked and can remove both tuples.
bkt.tuples.Remove(tuple)
return true
}
diff --git a/pkg/tcpip/stack/conntrack_test.go b/pkg/tcpip/stack/conntrack_test.go
new file mode 100644
index 000000000..fb0645ed1
--- /dev/null
+++ b/pkg/tcpip/stack/conntrack_test.go
@@ -0,0 +1,132 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+)
+
+func TestReap(t *testing.T) {
+ // Initialize conntrack.
+ clock := faketime.NewManualClock()
+ ct := ConnTrack{
+ clock: clock,
+ }
+ ct.init()
+ ct.checkNumTuples(t, 0)
+
+ // Simulate sending a SYN. This will get the connection into conntrack, but
+ // the connection won't be considered established. Thus the timeout for
+ // reaping is unestablishedTimeout.
+ pkt1 := genTCPPacket()
+ pkt1.tuple = ct.getConnOrMaybeInsertNoop(pkt1)
+ // We set rt.routeInfo.Loop to avoid a panic when handlePacket calls
+ // rt.RequiresTXTransportChecksum.
+ var rt Route
+ rt.routeInfo.Loop = PacketLoop
+ if pkt1.tuple.conn.handlePacket(pkt1, Output, &rt) {
+ t.Fatal("handlePacket() shouldn't perform any NAT")
+ }
+ ct.checkNumTuples(t, 1)
+
+ // Travel a little into the future and send the same SYN. This should update
+ // lastUsed, but per #6748 didn't.
+ clock.Advance(unestablishedTimeout / 2)
+ pkt2 := genTCPPacket()
+ pkt2.tuple = ct.getConnOrMaybeInsertNoop(pkt2)
+ if pkt2.tuple.conn.handlePacket(pkt2, Output, &rt) {
+ t.Fatal("handlePacket() shouldn't perform any NAT")
+ }
+ ct.checkNumTuples(t, 1)
+
+ // Travel farther into the future - enough that failing to update lastUsed
+ // would cause a reaping - and reap the whole table. Make sure the connection
+ // hasn't been reaped.
+ clock.Advance(unestablishedTimeout * 3 / 4)
+ ct.reapEverything()
+ ct.checkNumTuples(t, 1)
+
+ // Travel past unestablishedTimeout to confirm the tuple is gone.
+ clock.Advance(unestablishedTimeout / 2)
+ ct.reapEverything()
+ ct.checkNumTuples(t, 0)
+}
+
+// genTCPPacket returns an initialized IPv4 TCP packet.
+func genTCPPacket() *PacketBuffer {
+ const packetLen = header.IPv4MinimumSize + header.TCPMinimumSize
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: packetLen,
+ })
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
+ tcpHdr := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize))
+ tcpHdr.Encode(&header.TCPFields{
+ SrcPort: 5555,
+ DstPort: 6666,
+ SeqNum: 7777,
+ AckNum: 8888,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ Checksum: 0, // Conntrack doesn't verify the checksum.
+ })
+ ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+ ipHdr.Encode(&header.IPv4Fields{
+ TotalLength: packetLen,
+ Protocol: uint8(header.TCPProtocolNumber),
+ SrcAddr: testutil.MustParse4("1.0.0.1"),
+ DstAddr: testutil.MustParse4("1.0.0.2"),
+ Checksum: 0, // Conntrack doesn't verify the checksum.
+ })
+
+ return pkt
+}
+
+// checkNumTuples checks that there are exactly want tuples tracked by
+// conntrack.
+func (ct *ConnTrack) checkNumTuples(t *testing.T, want int) {
+ t.Helper()
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+
+ var total int
+ for idx := range ct.buckets {
+ ct.buckets[idx].mu.RLock()
+ total += ct.buckets[idx].tuples.Len()
+ ct.buckets[idx].mu.RUnlock()
+ }
+
+ if total != want {
+ t.Fatalf("checkNumTuples: got %d, wanted %d", total, want)
+ }
+}
+
+func (ct *ConnTrack) reapEverything() {
+ var bucket int
+ for {
+ newBucket, _ := ct.reapUnused(bucket, 0 /* ignored */)
+ // We started reaping at bucket 0. If the next bucket isn't after our
+ // current bucket, we've gone through them all.
+ if newBucket <= bucket {
+ break
+ }
+ bucket = newBucket
+ }
+}