diff options
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 64 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_test.go | 220 |
3 files changed, 253 insertions, 32 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index ead36880f..5d76adac1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -134,6 +134,7 @@ go_test( srcs = [ "conntrack_test.go", "forwarding_test.go", + "iptables_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 7fa657001..c51d5c09a 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -736,9 +736,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused deletes timed out entries from the conntrack map. The rules for // reaping are: -// - Most reaping occurs in connFor, which is called on each packet. connFor -// cleans up the bucket the packet's connection maps to. Thus calls to -// reapUnused should be fast. // - Each call to reapUnused traverses a fraction of the conntrack table. // Specifically, it traverses len(ct.buckets)/fractionPerReaping. // - After reaping, reapUnused decides when it should next run based on the @@ -805,45 +802,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // Precondition: ct.mu is read locked and bkt.mu is write locked. // +checklocksread:ct.mu // +checklocks:bkt.mu -func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { - if !tuple.conn.timedOut(now) { +func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { + if !reapingTuple.conn.timedOut(now) { return false } - // To maintain lock order, we can only reap both tuples if the reply appears - // later in the table. - replyBktID := ct.bucket(tuple.id().reply()) - tuple.conn.mu.RLock() - replyTupleInserted := tuple.conn.finalized - tuple.conn.mu.RUnlock() - if bktID > replyBktID && replyTupleInserted { - return true + var otherTuple *tuple + if reapingTuple.reply { + otherTuple = &reapingTuple.conn.original + } else { + otherTuple = &reapingTuple.conn.reply } - // Reap the reply. - if replyTupleInserted { - // Don't re-lock if both tuples are in the same bucket. - if bktID != replyBktID { - replyBkt := &ct.buckets[replyBktID] - replyBkt.mu.Lock() - removeConnFromBucket(replyBkt, tuple) - replyBkt.mu.Unlock() - } else { - removeConnFromBucket(bkt, tuple) - } + otherTupleBktID := ct.bucket(otherTuple.id()) + reapingTuple.conn.mu.RLock() + replyTupleInserted := reapingTuple.conn.finalized + reapingTuple.conn.mu.RUnlock() + + // To maintain lock order, we can only reap both tuples if the tuple for the + // other direction appears later in the table. + if bktID > otherTupleBktID && replyTupleInserted { + return true } - bkt.tuples.Remove(tuple) - return true -} + bkt.tuples.Remove(reapingTuple) + + if !replyTupleInserted { + // The other tuple is the reply which has not yet been inserted. + return true + } -// +checklocks:b.mu -func removeConnFromBucket(b *bucket, tuple *tuple) { - if tuple.reply { - b.tuples.Remove(&tuple.conn.original) + // Reap the other connection. + if bktID == otherTupleBktID { + // Don't re-lock if both tuples are in the same bucket. + bkt.tuples.Remove(otherTuple) } else { - b.tuples.Remove(&tuple.conn.reply) + otherTupleBkt := &ct.buckets[otherTupleBktID] + otherTupleBkt.mu.Lock() + otherTupleBkt.tuples.Remove(otherTuple) + otherTupleBkt.mu.Unlock() } + + return true } func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { diff --git a/pkg/tcpip/stack/iptables_test.go b/pkg/tcpip/stack/iptables_test.go new file mode 100644 index 000000000..1788e98c9 --- /dev/null +++ b/pkg/tcpip/stack/iptables_test.go @@ -0,0 +1,220 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// TestNATedConnectionReap tests that NATed connections are properly reaped. +func TestNATedConnectionReap(t *testing.T) { + // Note that the network protocol used for this test doesn't matter as this + // test focuses on reaping, not anything related to a specific network + // protocol. + + const ( + nattedDstPort = 1 + srcPort = 2 + dstPort = 3 + + nattedDstAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + srcAddr = tcpip.Address("\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + dstAddr = tcpip.Address("\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ) + + clock := faketime.NewManualClock() + iptables := DefaultTables(0 /* seed */, clock) + + table := Table{ + Rules: []Rule{ + // Prerouting + { + Target: &DNATTarget{NetworkProtocol: header.IPv6ProtocolNumber, Addr: nattedDstAddr, Port: nattedDstPort}, + }, + { + Target: &AcceptTarget{}, + }, + + // Input + { + Target: &AcceptTarget{}, + }, + + // Forward + { + Target: &AcceptTarget{}, + }, + + // Output + { + Target: &AcceptTarget{}, + }, + + // Postrouting + { + Target: &AcceptTarget{}, + }, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 2, + Forward: 3, + Output: 4, + Postrouting: 5, + }, + } + if err := iptables.ReplaceTable(NATID, table, true /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, true): %s", NATID, err) + } + + // Stop the reaper if it is running so we can reap manually as it is started + // on the first change to IPTables. + iptables.reaperDone <- struct{}{} + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + header.UDPMinimumSize, + }) + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udp.SetSourcePort(srcPort) + udp.SetDestinationPort(dstPort) + udp.SetChecksum(0) + udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + srcAddr, + dstAddr, + uint16(len(udp)), + ))) + pkt.TransportProtocolNumber = header.UDPProtocolNumber + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(udp)), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 64, + SrcAddr: srcAddr, + DstAddr: dstAddr, + }) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + + originalTID, _, ok := getTupleID(pkt) + if !ok { + t.Fatal("failed to get original tuple ID") + } + + if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) { + t.Fatal("got ipt.CheckPrerouting(...) = false, want = true") + } + if !iptables.CheckInput(pkt, "" /* inNicName */) { + t.Fatal("got ipt.CheckInput(...) = false, want = true") + } + + invertedReplyTID, _, ok := getTupleID(pkt) + if !ok { + t.Fatal("failed to get NATed packet's tuple ID") + } + if invertedReplyTID == originalTID { + t.Fatalf("NAT not performed; got invertedReplyTID = %#v", invertedReplyTID) + } + replyTID := invertedReplyTID.reply() + + originalBktID := iptables.connections.bucket(originalTID) + replyBktID := iptables.connections.bucket(replyTID) + + // This test depends on the original and reply tuples mapping to different + // buckets. + if originalBktID == replyBktID { + t.Fatalf("expected bucket IDs to be different; got = %d", originalBktID) + } + + lowerBktID := originalBktID + if lowerBktID > replyBktID { + lowerBktID = replyBktID + } + + runReaper := func() { + // Reaping the bucket with the lower ID should reap both tuples of the + // connection if it has timed out. + // + // We will manually pick the next start bucket ID and don't use the + // interval so we ignore the return values. + _, _ = iptables.connections.reapUnused(lowerBktID, 0 /* prevInterval */) + } + + iptables.connections.mu.RLock() + buckets := iptables.connections.buckets + iptables.connections.mu.RUnlock() + + originalBkt := &buckets[originalBktID] + replyBkt := &buckets[replyBktID] + + // Run the reaper and make sure the tuples were not reaped. + reapAndCheckForConnections := func() { + t.Helper() + + runReaper() + + now := clock.NowMonotonic() + if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple == nil { + t.Error("expected to get original tuple") + } + + if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple == nil { + t.Error("expected to get reply tuple") + } + + if t.Failed() { + t.FailNow() + } + } + + // Connection was just added and no time has passed - it should not be reaped. + reapAndCheckForConnections() + + // Time must advance past the unestablished timeout for a connection to be + // reaped. + clock.Advance(unestablishedTimeout) + reapAndCheckForConnections() + + // Connection should now be reaped. + clock.Advance(1) + runReaper() + now := clock.NowMonotonic() + if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple != nil { + t.Errorf("got originalBkt.connForTID(%#v, %#v) = %#v, want = nil", originalTID, now, originalTuple) + } + if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple != nil { + t.Errorf("got replyBkt.connForTID(%#v, %#v) = %#v, want = nil", replyTID, now, replyTuple) + } + // Make sure we don't have stale tuples just lying around. + // + // We manually check the buckets as connForTID will skip over tuples that + // have timed out. + checkNoTupleInBucket := func(bkt *bucket, tid tupleID, reply bool) { + t.Helper() + + bkt.mu.RLock() + defer bkt.mu.RUnlock() + for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() { + if tuple.id() == originalTID { + t.Errorf("unexpectedly found tuple with ID = %#v; reply = %t", tid, reply) + } + } + } + checkNoTupleInBucket(originalBkt, originalTID, false /* reply */) + checkNoTupleInBucket(replyBkt, replyTID, true /* reply */) +} |