summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/conntrack.go110
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go46
-rw-r--r--pkg/tcpip/stack/packet_buffer.go23
-rw-r--r--pkg/tcpip/tests/integration/BUILD4
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go288
-rw-r--r--pkg/tcpip/tests/utils/utils.go36
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go2
8 files changed, 420 insertions, 93 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 068dab7ce..4fb7e9adb 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
+ if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
@@ -209,27 +215,38 @@ type bucket struct {
tuples tupleList
}
+func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
+ return tcpHeader, true
+ }
+ case header.UDPProtocolNumber:
+ if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
+ return udpHeader, true
+ }
+ }
+
+ return nil, false
+}
+
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
//
// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
netHeader := pkt.Network()
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return tupleID{}, &tcpip.ErrUnknownProtocol{}
}
return tupleID{
srcAddr: netHeader.SourceAddress(),
- srcPort: tcpHeader.SourcePort(),
+ srcPort: transportHeader.SourcePort(),
dstAddr: netHeader.DestinationAddress(),
- dstPort: tcpHeader.DestinationPort(),
- transProto: netHeader.TransportProtocol(),
+ dstPort: transportHeader.DestinationPort(),
+ transProto: pkt.TransportProtocolNumber,
netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
- // TODO(gvisor.dev/issue/6168): Support UDP.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return false
}
@@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
}
netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return false
- }
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
@@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
switch hook {
case Prerouting, Output:
- if conn.manip == manipDestination {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- case dirReply:
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
-
- updateSRCFields = true
- }
+ if conn.manip == manipDestination && dir == dirOriginal {
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
+ pkt.NatDone = true
+ } else if conn.manip == manipSource && dir == dirReply {
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
pkt.NatDone = true
}
case Input, Postrouting:
- if conn.manip == manipSource {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
-
- updateSRCFields = true
- case dirReply:
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
- }
+ if conn.manip == manipSource && dir == dirOriginal {
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+ updateSRCFields = true
+ pkt.NatDone = true
+ } else if conn.manip == manipDestination && dir == dirReply {
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+ updateSRCFields = true
pkt.NatDone = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
+
if !pkt.NatDone {
return false
}
@@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
fullChecksum := false
updatePseudoHeader := false
switch hook {
- case Prerouting, Input:
+ case Prerouting:
+ // Packet came from outside the stack so it must have a checksum set
+ // already.
+ fullChecksum = true
+ updatePseudoHeader = true
+ case Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
fullChecksum = true
@@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
rewritePacket(
netHeader,
- tcpHeader,
+ transportHeader,
updateSRCFields,
fullChecksum,
updatePseudoHeader,
@@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// Mark the connection as having been used recently so it isn't reaped.
conn.lastUsed = time.Now()
// Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ conn.updateLocked(pkt, hook)
return false
}
@@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
- // We only track TCP connections.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber, header.UDPProtocolNumber:
+ default:
+ // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable
+ // connections.
return
}
@@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ conn.updateLocked(pkt, hook)
ct.insertConn(conn)
}
@@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
srcPort: epID.LocalPort,
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
- transProto: header.TCPProtocolNumber,
+ transProto: transProto,
netProto: netProto,
}
conn, _ := ct.connForTID(tid)
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index f152c0d83..3617b6dd0 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -482,11 +482,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, &tcpip.ErrNotConnected{}
}
- return it.connections.originalDst(epID, netProto)
+ return it.connections.originalDst(epID, netProto, transProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 96cc899bb..de5997e9e 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -206,34 +206,28 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- header.UDP(pkt.TransportHeader().View()),
- true, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- st.Port,
- st.Addr,
- )
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
+ port := st.Port
+
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ if port == 0 {
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ }
+ case header.TCPProtocolNumber:
+ if port == 0 {
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ }
}
+ }
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
+ // Set up conection for matching NAT rule. Only the first packet of the
+ // connection comes here. Other packets will be manipulated in connection
+ // tracking.
+ //
+ // Does nothing if the protocol does not support connection tracking.
+ if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil {
+ ct.handlePacket(pkt, hook, r)
}
return RuleAccept, 0
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index b9280c2de..bf248ef20 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -335,9 +335,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
// tell if a noop connection should be inserted at Input hook. Once conntrack
// redefines the manipulation field as mutable, we won't need the special noop
// connection.
- if pk.NatDone {
- newPk.NatDone = true
- }
+ newPk.NatDone = pk.NatDone
return newPk
}
@@ -347,7 +345,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
// The returned packet buffer will have the network and transport headers
// set if the original packet buffer did.
func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
- newPkt := NewPacketBuffer(PacketBufferOptions{
+ newPk := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: reservedHeaderBytes,
Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(),
IsForwardedPacket: true,
@@ -355,21 +353,28 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu
{
consumeBytes := pk.NetworkHeader().View().Size()
- if _, consumed := newPkt.NetworkHeader().Consume(consumeBytes); !consumed {
+ if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
}
- newPkt.NetworkProtocolNumber = pk.NetworkProtocolNumber
+ newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
{
consumeBytes := pk.TransportHeader().View().Size()
- if _, consumed := newPkt.TransportHeader().Consume(consumeBytes); !consumed {
+ if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
}
- newPkt.TransportProtocolNumber = pk.TransportProtocolNumber
+ newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
- return newPkt
+ // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
+ // maintain this flag in the packet. Currently conntrack needs this flag to
+ // tell if a noop connection should be inserted at Input hook. Once conntrack
+ // redefines the manipulation field as mutable, we won't need the special noop
+ // connection.
+ newPk.NatDone = pk.NatDone
+
+ return newPk
}
// headerInfo stores metadata about a header in a packet.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 181ef799e..7c998eaae 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -34,12 +34,16 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
"//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 28b49c6be..bdf4a64b9 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -15,19 +15,24 @@
package iptables_test
import (
+ "bytes"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type inputIfNameMatcher struct {
@@ -1156,3 +1161,286 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
})
}
}
+
+func TestSNAT(t *testing.T) {
+ const listenPort = 8080
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+
+ nattedClientAddr tcpip.Address
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+ t.Cleanup(ep.Close)
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ipt := routerStack.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err)
+ }
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ {
+ name: "IPv6 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ipt := routerStack.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err)
+ }
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ expectedConnectErr tcpip.Error
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ needRemoteAddr bool
+ }{
+ {
+ name: "UDP",
+ proto: udp.ProtocolNumber,
+ expectedConnectErr: nil,
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ if err := ep.Connect(clientAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
+ }
+ return nil, nil
+ },
+ needRemoteAddr: true,
+ },
+ {
+ name: "TCP",
+ proto: tcp.ProtocolNumber,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
+ t.Helper()
+
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("ep.Listen(1): %s", err)
+ }
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ var addr tcpip.FullAddress
+ for {
+ newEP, wq, err := ep.Accept(&addr)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Accept(_): %s", err)
+ }
+ if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
+ "NIC",
+ )); diff != "" {
+ t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
+ }
+
+ we, newCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ return newEP, newCH
+ }
+ },
+ needRemoteAddr: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(serverAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
+ }
+ }
+ nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ nattedClientAddr.Port = addr.Port
+ }
+
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
+ }
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, nattedClientAddr)
+ }
+
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 947bcc7b1..c69410859 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -40,6 +40,14 @@ const (
Host2NICID = 4
)
+// Common NIC names used by tests.
+const (
+ Host1NICName = "host1NIC"
+ RouterNIC1Name = "routerNIC1"
+ RouterNIC2Name = "routerNIC2"
+ Host2NICName = "host2NIC"
+)
+
// Common link addresses used by tests.
const (
LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
@@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
- if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host1NICName}
+ if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil {
+ t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC1Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC2Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err)
+ }
}
- if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host2NICName}
+ if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil {
+ t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err)
+ }
}
if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a3002abf3..407ab2664 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2066,7 +2066,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber)
e.UnlockUser()
if err != nil {
return err