summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-11-02 13:14:23 -0700
committergVisor bot <gvisor-bot@google.com>2021-11-02 13:18:11 -0700
commitea792cb3e1b3c1f2c34b2ffd7dbfde5d935b8a74 (patch)
tree9c3f9256c425556e59a8d801b532d187913d0022
parent1e1d6b2be37873c5e62461834df973f41565c662 (diff)
Properly reap NATed connections
This change fixes a bug when reaping tuples of NAT-ed connections. Previously when reaping a tuple, the other direction's tuple ID was calculated by taking the reaping tuple's ID and inverting it. This works when a connection is not NATed but doesn't work when NAT is performed as the other direction's tuple may use different addresses. PiperOrigin-RevId: 407160930
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/conntrack.go64
-rw-r--r--pkg/tcpip/stack/iptables_test.go220
3 files changed, 253 insertions, 32 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index ead36880f..5d76adac1 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -134,6 +134,7 @@ go_test(
srcs = [
"conntrack_test.go",
"forwarding_test.go",
+ "iptables_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
"nic_test.go",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 7fa657001..c51d5c09a 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -736,9 +736,6 @@ func (ct *ConnTrack) bucket(id tupleID) int {
// reapUnused deletes timed out entries from the conntrack map. The rules for
// reaping are:
-// - Most reaping occurs in connFor, which is called on each packet. connFor
-// cleans up the bucket the packet's connection maps to. Thus calls to
-// reapUnused should be fast.
// - Each call to reapUnused traverses a fraction of the conntrack table.
// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
// - After reaping, reapUnused decides when it should next run based on the
@@ -805,45 +802,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// Precondition: ct.mu is read locked and bkt.mu is write locked.
// +checklocksread:ct.mu
// +checklocks:bkt.mu
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
- if !tuple.conn.timedOut(now) {
+func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
+ if !reapingTuple.conn.timedOut(now) {
return false
}
- // To maintain lock order, we can only reap both tuples if the reply appears
- // later in the table.
- replyBktID := ct.bucket(tuple.id().reply())
- tuple.conn.mu.RLock()
- replyTupleInserted := tuple.conn.finalized
- tuple.conn.mu.RUnlock()
- if bktID > replyBktID && replyTupleInserted {
- return true
+ var otherTuple *tuple
+ if reapingTuple.reply {
+ otherTuple = &reapingTuple.conn.original
+ } else {
+ otherTuple = &reapingTuple.conn.reply
}
- // Reap the reply.
- if replyTupleInserted {
- // Don't re-lock if both tuples are in the same bucket.
- if bktID != replyBktID {
- replyBkt := &ct.buckets[replyBktID]
- replyBkt.mu.Lock()
- removeConnFromBucket(replyBkt, tuple)
- replyBkt.mu.Unlock()
- } else {
- removeConnFromBucket(bkt, tuple)
- }
+ otherTupleBktID := ct.bucket(otherTuple.id())
+ reapingTuple.conn.mu.RLock()
+ replyTupleInserted := reapingTuple.conn.finalized
+ reapingTuple.conn.mu.RUnlock()
+
+ // To maintain lock order, we can only reap both tuples if the tuple for the
+ // other direction appears later in the table.
+ if bktID > otherTupleBktID && replyTupleInserted {
+ return true
}
- bkt.tuples.Remove(tuple)
- return true
-}
+ bkt.tuples.Remove(reapingTuple)
+
+ if !replyTupleInserted {
+ // The other tuple is the reply which has not yet been inserted.
+ return true
+ }
-// +checklocks:b.mu
-func removeConnFromBucket(b *bucket, tuple *tuple) {
- if tuple.reply {
- b.tuples.Remove(&tuple.conn.original)
+ // Reap the other connection.
+ if bktID == otherTupleBktID {
+ // Don't re-lock if both tuples are in the same bucket.
+ bkt.tuples.Remove(otherTuple)
} else {
- b.tuples.Remove(&tuple.conn.reply)
+ otherTupleBkt := &ct.buckets[otherTupleBktID]
+ otherTupleBkt.mu.Lock()
+ otherTupleBkt.tuples.Remove(otherTuple)
+ otherTupleBkt.mu.Unlock()
}
+
+ return true
}
func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
diff --git a/pkg/tcpip/stack/iptables_test.go b/pkg/tcpip/stack/iptables_test.go
new file mode 100644
index 000000000..1788e98c9
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// TestNATedConnectionReap tests that NATed connections are properly reaped.
+func TestNATedConnectionReap(t *testing.T) {
+ // Note that the network protocol used for this test doesn't matter as this
+ // test focuses on reaping, not anything related to a specific network
+ // protocol.
+
+ const (
+ nattedDstPort = 1
+ srcPort = 2
+ dstPort = 3
+
+ nattedDstAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ srcAddr = tcpip.Address("\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ dstAddr = tcpip.Address("\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ )
+
+ clock := faketime.NewManualClock()
+ iptables := DefaultTables(0 /* seed */, clock)
+
+ table := Table{
+ Rules: []Rule{
+ // Prerouting
+ {
+ Target: &DNATTarget{NetworkProtocol: header.IPv6ProtocolNumber, Addr: nattedDstAddr, Port: nattedDstPort},
+ },
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Target: &AcceptTarget{},
+ },
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Input: 2,
+ Forward: 3,
+ Output: 4,
+ Postrouting: 5,
+ },
+ }
+ if err := iptables.ReplaceTable(NATID, table, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, true): %s", NATID, err)
+ }
+
+ // Stop the reaper if it is running so we can reap manually as it is started
+ // on the first change to IPTables.
+ iptables.reaperDone <- struct{}{}
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize + header.UDPMinimumSize,
+ })
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ udp.SetSourcePort(srcPort)
+ udp.SetDestinationPort(dstPort)
+ udp.SetChecksum(0)
+ udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
+ header.UDPProtocolNumber,
+ srcAddr,
+ dstAddr,
+ uint16(len(udp)),
+ )))
+ pkt.TransportProtocolNumber = header.UDPProtocolNumber
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(udp)),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: 64,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ })
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+
+ originalTID, _, ok := getTupleID(pkt)
+ if !ok {
+ t.Fatal("failed to get original tuple ID")
+ }
+
+ if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) {
+ t.Fatal("got ipt.CheckPrerouting(...) = false, want = true")
+ }
+ if !iptables.CheckInput(pkt, "" /* inNicName */) {
+ t.Fatal("got ipt.CheckInput(...) = false, want = true")
+ }
+
+ invertedReplyTID, _, ok := getTupleID(pkt)
+ if !ok {
+ t.Fatal("failed to get NATed packet's tuple ID")
+ }
+ if invertedReplyTID == originalTID {
+ t.Fatalf("NAT not performed; got invertedReplyTID = %#v", invertedReplyTID)
+ }
+ replyTID := invertedReplyTID.reply()
+
+ originalBktID := iptables.connections.bucket(originalTID)
+ replyBktID := iptables.connections.bucket(replyTID)
+
+ // This test depends on the original and reply tuples mapping to different
+ // buckets.
+ if originalBktID == replyBktID {
+ t.Fatalf("expected bucket IDs to be different; got = %d", originalBktID)
+ }
+
+ lowerBktID := originalBktID
+ if lowerBktID > replyBktID {
+ lowerBktID = replyBktID
+ }
+
+ runReaper := func() {
+ // Reaping the bucket with the lower ID should reap both tuples of the
+ // connection if it has timed out.
+ //
+ // We will manually pick the next start bucket ID and don't use the
+ // interval so we ignore the return values.
+ _, _ = iptables.connections.reapUnused(lowerBktID, 0 /* prevInterval */)
+ }
+
+ iptables.connections.mu.RLock()
+ buckets := iptables.connections.buckets
+ iptables.connections.mu.RUnlock()
+
+ originalBkt := &buckets[originalBktID]
+ replyBkt := &buckets[replyBktID]
+
+ // Run the reaper and make sure the tuples were not reaped.
+ reapAndCheckForConnections := func() {
+ t.Helper()
+
+ runReaper()
+
+ now := clock.NowMonotonic()
+ if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple == nil {
+ t.Error("expected to get original tuple")
+ }
+
+ if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple == nil {
+ t.Error("expected to get reply tuple")
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ // Connection was just added and no time has passed - it should not be reaped.
+ reapAndCheckForConnections()
+
+ // Time must advance past the unestablished timeout for a connection to be
+ // reaped.
+ clock.Advance(unestablishedTimeout)
+ reapAndCheckForConnections()
+
+ // Connection should now be reaped.
+ clock.Advance(1)
+ runReaper()
+ now := clock.NowMonotonic()
+ if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple != nil {
+ t.Errorf("got originalBkt.connForTID(%#v, %#v) = %#v, want = nil", originalTID, now, originalTuple)
+ }
+ if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple != nil {
+ t.Errorf("got replyBkt.connForTID(%#v, %#v) = %#v, want = nil", replyTID, now, replyTuple)
+ }
+ // Make sure we don't have stale tuples just lying around.
+ //
+ // We manually check the buckets as connForTID will skip over tuples that
+ // have timed out.
+ checkNoTupleInBucket := func(bkt *bucket, tid tupleID, reply bool) {
+ t.Helper()
+
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ if tuple.id() == originalTID {
+ t.Errorf("unexpectedly found tuple with ID = %#v; reply = %t", tid, reply)
+ }
+ }
+ }
+ checkNoTupleInBucket(originalBkt, originalTID, false /* reply */)
+ checkNoTupleInBucket(replyBkt, replyTID, true /* reply */)
+}