summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/conntrack.go146
-rw-r--r--pkg/tcpip/stack/iptables_test.go220
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go270
4 files changed, 453 insertions, 184 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index ead36880f..5d76adac1 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -134,6 +134,7 @@ go_test(
srcs = [
"conntrack_test.go",
"forwarding_test.go",
+ "iptables_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
"nic_test.go",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index c489506bb..c51d5c09a 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -209,17 +209,41 @@ type bucket struct {
tuples tupleList
}
-func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) {
- switch pkt.tuple.id().transProto {
+// A netAndTransHeadersFunc returns the network and transport headers found
+// in an ICMP payload. The transport layer's payload will not be returned.
+//
+// May panic if the packet does not hold the transport header.
+type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte)
+
+func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv4(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the buffer is smaller than
+ // the total length specified in the IPv4 header.
+ transHdr := icmpPayload[netHdr.HeaderLength():]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
+ netHdr := header.IPv6(icmpPayload)
+ // Do not use netHdr.Payload() as we might not hold the full packet
+ // in the ICMP error; Payload() panics if the IP payload is smaller than
+ // the payload length specified in the IPv6 header.
+ transHdr := icmpPayload[header.IPv6MinimumSize:]
+ return netHdr, transHdr[:minTransHdrLen]
+}
+
+func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) {
+ switch transProto {
case header.TCPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.TCP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize)
+ return netHeader, header.TCP(transHeaderBytes), true
}
case header.UDPProtocolNumber:
if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok {
- netHeader := netHdrFunc(netAndTransHeader)
- return netHeader, header.UDP(netHeader.Payload()), true
+ netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize)
+ return netHeader, header.UDP(transHeaderBytes), true
}
}
return nil, nil, false
@@ -246,7 +270,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic("should have dropped packets with IPv4 options")
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok {
return netHdr, transHdr, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -264,7 +288,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check
panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto))
}
- if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok {
return netHdr, transHdr, true, true
}
}
@@ -283,34 +307,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro
}
}
-func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
- switch transProto {
- case header.TCPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.TCP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
- case header.UDPProtocolNumber:
- if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok {
- netHdr := netHdrFunc(netAndTransHeader)
- transHdr := header.UDP(netHdr.Payload())
- return tupleID{
- srcAddr: netHdr.DestinationAddress(),
- srcPort: transHdr.DestinationPort(),
- dstAddr: netHdr.SourceAddress(),
- dstPort: transHdr.SourcePort(),
- transProto: transProto,
- netProto: netProto,
- }, true
- }
+func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
+ if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok {
+ return tupleID{
+ srcAddr: netHdr.DestinationAddress(),
+ srcPort: transHdr.DestinationPort(),
+ dstAddr: netHdr.SourceAddress(),
+ dstPort: transHdr.SourcePort(),
+ transProto: transProto,
+ netProto: netProto,
+ }, true
}
return tupleID{}, false
@@ -349,7 +355,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
return tupleID{}, false, false
}
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
return tid, true, true
}
case header.ICMPv6ProtocolNumber:
@@ -370,7 +376,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
}
// TODO(https://gvisor.dev/issue/6789): Handle extension headers.
- if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
+ if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
return tid, true, true
}
}
@@ -730,9 +736,6 @@ func (ct *ConnTrack) bucket(id tupleID) int {
// reapUnused deletes timed out entries from the conntrack map. The rules for
// reaping are:
-// - Most reaping occurs in connFor, which is called on each packet. connFor
-// cleans up the bucket the packet's connection maps to. Thus calls to
-// reapUnused should be fast.
// - Each call to reapUnused traverses a fraction of the conntrack table.
// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
// - After reaping, reapUnused decides when it should next run based on the
@@ -799,45 +802,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// Precondition: ct.mu is read locked and bkt.mu is write locked.
// +checklocksread:ct.mu
// +checklocks:bkt.mu
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
- if !tuple.conn.timedOut(now) {
+func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
+ if !reapingTuple.conn.timedOut(now) {
return false
}
- // To maintain lock order, we can only reap both tuples if the reply appears
- // later in the table.
- replyBktID := ct.bucket(tuple.id().reply())
- tuple.conn.mu.RLock()
- replyTupleInserted := tuple.conn.finalized
- tuple.conn.mu.RUnlock()
- if bktID > replyBktID && replyTupleInserted {
- return true
+ var otherTuple *tuple
+ if reapingTuple.reply {
+ otherTuple = &reapingTuple.conn.original
+ } else {
+ otherTuple = &reapingTuple.conn.reply
}
- // Reap the reply.
- if replyTupleInserted {
- // Don't re-lock if both tuples are in the same bucket.
- if bktID != replyBktID {
- replyBkt := &ct.buckets[replyBktID]
- replyBkt.mu.Lock()
- removeConnFromBucket(replyBkt, tuple)
- replyBkt.mu.Unlock()
- } else {
- removeConnFromBucket(bkt, tuple)
- }
+ otherTupleBktID := ct.bucket(otherTuple.id())
+ reapingTuple.conn.mu.RLock()
+ replyTupleInserted := reapingTuple.conn.finalized
+ reapingTuple.conn.mu.RUnlock()
+
+ // To maintain lock order, we can only reap both tuples if the tuple for the
+ // other direction appears later in the table.
+ if bktID > otherTupleBktID && replyTupleInserted {
+ return true
}
- bkt.tuples.Remove(tuple)
- return true
-}
+ bkt.tuples.Remove(reapingTuple)
-// +checklocks:b.mu
-func removeConnFromBucket(b *bucket, tuple *tuple) {
- if tuple.reply {
- b.tuples.Remove(&tuple.conn.original)
+ if !replyTupleInserted {
+ // The other tuple is the reply which has not yet been inserted.
+ return true
+ }
+
+ // Reap the other connection.
+ if bktID == otherTupleBktID {
+ // Don't re-lock if both tuples are in the same bucket.
+ bkt.tuples.Remove(otherTuple)
} else {
- b.tuples.Remove(&tuple.conn.reply)
+ otherTupleBkt := &ct.buckets[otherTupleBktID]
+ otherTupleBkt.mu.Lock()
+ otherTupleBkt.tuples.Remove(otherTuple)
+ otherTupleBkt.mu.Unlock()
}
+
+ return true
}
func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
diff --git a/pkg/tcpip/stack/iptables_test.go b/pkg/tcpip/stack/iptables_test.go
new file mode 100644
index 000000000..1788e98c9
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// TestNATedConnectionReap tests that NATed connections are properly reaped.
+func TestNATedConnectionReap(t *testing.T) {
+ // Note that the network protocol used for this test doesn't matter as this
+ // test focuses on reaping, not anything related to a specific network
+ // protocol.
+
+ const (
+ nattedDstPort = 1
+ srcPort = 2
+ dstPort = 3
+
+ nattedDstAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ srcAddr = tcpip.Address("\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ dstAddr = tcpip.Address("\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ )
+
+ clock := faketime.NewManualClock()
+ iptables := DefaultTables(0 /* seed */, clock)
+
+ table := Table{
+ Rules: []Rule{
+ // Prerouting
+ {
+ Target: &DNATTarget{NetworkProtocol: header.IPv6ProtocolNumber, Addr: nattedDstAddr, Port: nattedDstPort},
+ },
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Target: &AcceptTarget{},
+ },
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Input: 2,
+ Forward: 3,
+ Output: 4,
+ Postrouting: 5,
+ },
+ }
+ if err := iptables.ReplaceTable(NATID, table, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, true): %s", NATID, err)
+ }
+
+ // Stop the reaper if it is running so we can reap manually as it is started
+ // on the first change to IPTables.
+ iptables.reaperDone <- struct{}{}
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize + header.UDPMinimumSize,
+ })
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ udp.SetSourcePort(srcPort)
+ udp.SetDestinationPort(dstPort)
+ udp.SetChecksum(0)
+ udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
+ header.UDPProtocolNumber,
+ srcAddr,
+ dstAddr,
+ uint16(len(udp)),
+ )))
+ pkt.TransportProtocolNumber = header.UDPProtocolNumber
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(udp)),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: 64,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ })
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+
+ originalTID, _, ok := getTupleID(pkt)
+ if !ok {
+ t.Fatal("failed to get original tuple ID")
+ }
+
+ if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) {
+ t.Fatal("got ipt.CheckPrerouting(...) = false, want = true")
+ }
+ if !iptables.CheckInput(pkt, "" /* inNicName */) {
+ t.Fatal("got ipt.CheckInput(...) = false, want = true")
+ }
+
+ invertedReplyTID, _, ok := getTupleID(pkt)
+ if !ok {
+ t.Fatal("failed to get NATed packet's tuple ID")
+ }
+ if invertedReplyTID == originalTID {
+ t.Fatalf("NAT not performed; got invertedReplyTID = %#v", invertedReplyTID)
+ }
+ replyTID := invertedReplyTID.reply()
+
+ originalBktID := iptables.connections.bucket(originalTID)
+ replyBktID := iptables.connections.bucket(replyTID)
+
+ // This test depends on the original and reply tuples mapping to different
+ // buckets.
+ if originalBktID == replyBktID {
+ t.Fatalf("expected bucket IDs to be different; got = %d", originalBktID)
+ }
+
+ lowerBktID := originalBktID
+ if lowerBktID > replyBktID {
+ lowerBktID = replyBktID
+ }
+
+ runReaper := func() {
+ // Reaping the bucket with the lower ID should reap both tuples of the
+ // connection if it has timed out.
+ //
+ // We will manually pick the next start bucket ID and don't use the
+ // interval so we ignore the return values.
+ _, _ = iptables.connections.reapUnused(lowerBktID, 0 /* prevInterval */)
+ }
+
+ iptables.connections.mu.RLock()
+ buckets := iptables.connections.buckets
+ iptables.connections.mu.RUnlock()
+
+ originalBkt := &buckets[originalBktID]
+ replyBkt := &buckets[replyBktID]
+
+ // Run the reaper and make sure the tuples were not reaped.
+ reapAndCheckForConnections := func() {
+ t.Helper()
+
+ runReaper()
+
+ now := clock.NowMonotonic()
+ if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple == nil {
+ t.Error("expected to get original tuple")
+ }
+
+ if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple == nil {
+ t.Error("expected to get reply tuple")
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ // Connection was just added and no time has passed - it should not be reaped.
+ reapAndCheckForConnections()
+
+ // Time must advance past the unestablished timeout for a connection to be
+ // reaped.
+ clock.Advance(unestablishedTimeout)
+ reapAndCheckForConnections()
+
+ // Connection should now be reaped.
+ clock.Advance(1)
+ runReaper()
+ now := clock.NowMonotonic()
+ if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple != nil {
+ t.Errorf("got originalBkt.connForTID(%#v, %#v) = %#v, want = nil", originalTID, now, originalTuple)
+ }
+ if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple != nil {
+ t.Errorf("got replyBkt.connForTID(%#v, %#v) = %#v, want = nil", replyTID, now, replyTuple)
+ }
+ // Make sure we don't have stale tuples just lying around.
+ //
+ // We manually check the buckets as connForTID will skip over tuples that
+ // have timed out.
+ checkNoTupleInBucket := func(bkt *bucket, tid tupleID, reply bool) {
+ t.Helper()
+
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ if tuple.id() == originalTID {
+ t.Errorf("unexpectedly found tuple with ID = %#v; reply = %t", tid, reply)
+ }
+ }
+ }
+ checkNoTupleInBucket(originalBkt, originalTID, false /* reply */)
+ checkNoTupleInBucket(replyBkt, replyTID, true /* reply */)
+}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 7fe3b29d9..b2383576c 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -1781,8 +1781,11 @@ func TestNAT(t *testing.T) {
}
func TestNATICMPError(t *testing.T) {
- const srcPort = 1234
- const dstPort = 5432
+ const (
+ srcPort = 1234
+ dstPort = 5432
+ dataSize = 4
+ )
type icmpTypeTest struct {
name string
@@ -1836,8 +1839,7 @@ func TestNATICMPError(t *testing.T) {
netProto: ipv4.ProtocolNumber,
host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)
- hdr := buffer.NewPrependable(totalLen)
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original))
if n := copy(hdr.Prepend(len(original)), original); n != len(original) {
t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
}
@@ -1845,8 +1847,9 @@ func TestNATICMPError(t *testing.T) {
icmp.SetType(header.ICMPv4Type(icmpType))
icmp.SetChecksum(0)
icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.ICMPv4ProtocolNumber,
utils.Host1IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
@@ -1875,9 +1878,9 @@ func TestNATICMPError(t *testing.T) {
name: "UDP",
proto: header.UDPProtocolNumber,
buf: func() buffer.View {
- totalLen := header.IPv4MinimumSize + header.UDPMinimumSize
- hdr := buffer.NewPrependable(totalLen)
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpSize := header.UDPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize)
+ udp := header.UDP(hdr.Prepend(udpSize))
udp.SetSourcePort(srcPort)
udp.SetDestinationPort(dstPort)
udp.SetChecksum(0)
@@ -1887,8 +1890,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
uint16(len(udp)),
)))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.UDPProtocolNumber,
utils.Host2IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
@@ -1910,9 +1914,9 @@ func TestNATICMPError(t *testing.T) {
name: "TCP",
proto: header.TCPProtocolNumber,
buf: func() buffer.View {
- totalLen := header.IPv4MinimumSize + header.TCPMinimumSize
- hdr := buffer.NewPrependable(totalLen)
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcpSize := header.TCPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize)
+ tcp := header.TCP(hdr.Prepend(tcpSize))
tcp.SetSourcePort(srcPort)
tcp.SetDestinationPort(dstPort)
tcp.SetDataOffset(header.TCPMinimumSize)
@@ -1923,8 +1927,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
uint16(len(tcp)),
)))
- ipHdr(hdr.Prepend(header.IPv4MinimumSize),
- totalLen,
+ ipHdr(
+ hdr.Prepend(header.IPv4MinimumSize),
+ hdr.UsedLength(),
header.TCPProtocolNumber,
utils.Host2IPv4Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
@@ -1989,7 +1994,8 @@ func TestNATICMPError(t *testing.T) {
Src: utils.Host1IPv6Addr.AddressWithPrefix.Address,
Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
}))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
payloadLen,
header.ICMPv6ProtocolNumber,
utils.Host1IPv6Addr.AddressWithPrefix.Address,
@@ -2016,8 +2022,9 @@ func TestNATICMPError(t *testing.T) {
name: "UDP",
proto: header.UDPProtocolNumber,
buf: func() buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpSize := header.UDPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize)
+ udp := header.UDP(hdr.Prepend(udpSize))
udp.SetSourcePort(srcPort)
udp.SetDestinationPort(dstPort)
udp.SetChecksum(0)
@@ -2027,8 +2034,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
uint16(len(udp)),
)))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
- header.UDPMinimumSize,
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
+ len(udp),
header.UDPProtocolNumber,
utils.Host2IPv6Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -2050,8 +2058,9 @@ func TestNATICMPError(t *testing.T) {
name: "TCP",
proto: header.TCPProtocolNumber,
buf: func() buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.TCPMinimumSize)
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcpSize := header.TCPMinimumSize + dataSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize)
+ tcp := header.TCP(hdr.Prepend(tcpSize))
tcp.SetSourcePort(srcPort)
tcp.SetDestinationPort(dstPort)
tcp.SetDataOffset(header.TCPMinimumSize)
@@ -2062,8 +2071,9 @@ func TestNATICMPError(t *testing.T) {
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
uint16(len(tcp)),
)))
- ip6Hdr(hdr.Prepend(header.IPv6MinimumSize),
- header.TCPMinimumSize,
+ ip6Hdr(
+ hdr.Prepend(header.IPv6MinimumSize),
+ len(tcp),
header.TCPProtocolNumber,
utils.Host2IPv6Addr.AddressWithPrefix.Address,
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
@@ -2117,109 +2127,141 @@ func TestNATICMPError(t *testing.T) {
},
}
+ trimTests := []struct {
+ name string
+ trimLen int
+ expectNATedICMP bool
+ }{
+ {
+ name: "Trim nothing",
+ trimLen: 0,
+ expectNATedICMP: true,
+ },
+ {
+ name: "Trim data",
+ trimLen: dataSize,
+ expectNATedICMP: true,
+ },
+ {
+ name: "Trim data and transport header",
+ trimLen: dataSize + 1,
+ expectNATedICMP: false,
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for _, transportType := range test.transportTypes {
t.Run(transportType.name, func(t *testing.T) {
for _, icmpType := range test.icmpTypes {
t.Run(icmpType.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
-
- ep1 := channel.New(1, header.IPv6MinimumMTU, "")
- ep2 := channel.New(1, header.IPv6MinimumMTU, "")
- utils.SetupRouterStack(t, s, ep1, ep2)
-
- ipv6 := test.netProto == ipv6.ProtocolNumber
- ipt := s.IPTables()
-
- table := stack.Table{
- Rules: []stack.Rule{
- // Prerouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- InputInterface: utils.RouterNIC2Name,
+ for _, trimTest := range trimTests {
+ t.Run(trimTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ })
+
+ ep1 := channel.New(1, header.IPv6MinimumMTU, "")
+ ep2 := channel.New(1, header.IPv6MinimumMTU, "")
+ utils.SetupRouterStack(t, s, ep1, ep2)
+
+ ipv6 := test.netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Prerouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
},
- Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
- },
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Input
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Forward
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Output
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Postrouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- OutputInterface: utils.RouterNIC1Name,
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 2,
+ stack.Forward: 3,
+ stack.Output: 4,
+ stack.Postrouting: 5,
},
- Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
- },
- {
- Target: &stack.AcceptTarget{},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 2,
- stack.Forward: 3,
- stack.Output: 4,
- stack.Postrouting: 5,
- },
- }
+ }
- if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
- }
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
- ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(),
- }))
+ buf := transportType.buf
- {
- pkt, ok := ep1.Read()
- if !ok {
- t.Fatal("expected to read a packet on ep1")
- }
- pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
- transportType.checkNATed(t, pktView)
- if t.Failed() {
- t.FailNow()
- }
+ ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: append(buffer.View(nil), buf...).ToVectorisedView(),
+ }))
- ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
- }))
- }
+ {
+ pkt, ok := ep1.Read()
+ if !ok {
+ t.Fatal("expected to read a packet on ep1")
+ }
+ pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
+ transportType.checkNATed(t, pktView)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ pktView = pktView[:len(pktView)-trimTest.trimLen]
+ buf = buf[:len(buf)-trimTest.trimLen]
+
+ ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
+ }))
+ }
- pkt, ok := ep2.Read()
- if ok != icmpType.expectResponse {
- t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse)
- }
- if !icmpType.expectResponse {
- return
+ pkt, ok := ep2.Read()
+ expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP
+ if ok != expectResponse {
+ t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse)
+ }
+ if !expectResponse {
+ return
+ }
+ test.decrementTTL(buf)
+ test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val)
+ })
}
- test.decrementTTL(transportType.buf)
- test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val)
})
}
})