From d8772545113ff941d34a4eae5f43df3f39d3547f Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 22 Sep 2021 17:52:43 -0700 Subject: Track UDP connections This will enable NAT to be performed on UDP packets that are sent in response to packets sent by the stack. This will also enable ICMP errors to be properly NAT-ed in response to UDP packets (#5916). Updates #5915. PiperOrigin-RevId: 398373251 --- pkg/tcpip/stack/conntrack.go | 110 +++++----- pkg/tcpip/stack/iptables.go | 4 +- pkg/tcpip/stack/iptables_targets.go | 46 ++--- pkg/tcpip/stack/packet_buffer.go | 23 ++- pkg/tcpip/tests/integration/BUILD | 4 + pkg/tcpip/tests/integration/iptables_test.go | 288 +++++++++++++++++++++++++++ pkg/tcpip/tests/utils/utils.go | 36 +++- pkg/tcpip/transport/tcp/endpoint.go | 2 +- 8 files changed, 420 insertions(+), 93 deletions(-) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 068dab7ce..4fb7e9adb 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // // Precondition: cn.mu must be held. -func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) { + if pkt.TransportProtocolNumber != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. @@ -209,27 +215,38 @@ type bucket struct { tuples tupleList } +func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { + return tcpHeader, true + } + case header.UDPProtocolNumber: + if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { + return udpHeader, true + } + } + + return nil, false +} + // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, &tcpip.ErrUnknownProtocol{} - } - - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ srcAddr: netHeader.SourceAddress(), - srcPort: tcpHeader.SourcePort(), + srcPort: transportHeader.SourcePort(), dstAddr: netHeader.DestinationAddress(), - dstPort: tcpHeader.DestinationPort(), - transProto: netHeader.TransportProtocol(), + dstPort: transportHeader.DestinationPort(), + transProto: pkt.TransportProtocolNumber, netProto: pkt.NetworkProtocolNumber, }, nil } @@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/6168): Support UDP. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return false } @@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return false - } // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be // validated if checksum offloading is off. It may require IP defrag if the @@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { switch hook { case Prerouting, Output: - if conn.manip == manipDestination { - switch dir { - case dirOriginal: - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - case dirReply: - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - - updateSRCFields = true - } + if conn.manip == manipDestination && dir == dirOriginal { + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr + pkt.NatDone = true + } else if conn.manip == manipSource && dir == dirReply { + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource { - switch dir { - case dirOriginal: - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - - updateSRCFields = true - case dirReply: - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr - } + if conn.manip == manipSource && dir == dirOriginal { + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + updateSRCFields = true + pkt.NatDone = true + } else if conn.manip == manipDestination && dir == dirReply { + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + updateSRCFields = true pkt.NatDone = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } + if !pkt.NatDone { return false } @@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { fullChecksum := false updatePseudoHeader := false switch hook { - case Prerouting, Input: + case Prerouting: + // Packet came from outside the stack so it must have a checksum set + // already. + fullChecksum = true + updatePseudoHeader = true + case Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { fullChecksum = true @@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { rewritePacket( netHeader, - tcpHeader, + transportHeader, updateSRCFields, fullChecksum, updatePseudoHeader, @@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) return false } @@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } - // We only track TCP connections. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber, header.UDPProtocolNumber: + default: + // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable + // connections. return } @@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) ct.insertConn(conn) } @@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ srcPort: epID.LocalPort, dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, - transProto: header.TCPProtocolNumber, + transProto: transProto, netProto: netProto, } conn, _ := ct.connForTID(tid) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f152c0d83..3617b6dd0 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -482,11 +482,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, &tcpip.ErrNotConnected{} } - return it.connections.originalDst(epID, netProto) + return it.connections.originalDst(epID, netProto, transProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 96cc899bb..de5997e9e 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -206,34 +206,28 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - header.UDP(pkt.TransportHeader().View()), - true, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - st.Port, - st.Addr, - ) - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 + port := st.Port + + if port == 0 { + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + if port == 0 { + port = header.UDP(pkt.TransportHeader().View()).SourcePort() + } + case header.TCPProtocolNumber: + if port == 0 { + port = header.TCP(pkt.TransportHeader().View()).SourcePort() + } } + } - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) - } - default: - return RuleDrop, 0 + // Set up conection for matching NAT rule. Only the first packet of the + // connection comes here. Other packets will be manipulated in connection + // tracking. + // + // Does nothing if the protocol does not support connection tracking. + if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil { + ct.handlePacket(pkt, hook, r) } return RuleAccept, 0 diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index b9280c2de..bf248ef20 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -335,9 +335,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // tell if a noop connection should be inserted at Input hook. Once conntrack // redefines the manipulation field as mutable, we won't need the special noop // connection. - if pk.NatDone { - newPk.NatDone = true - } + newPk.NatDone = pk.NatDone return newPk } @@ -347,7 +345,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // The returned packet buffer will have the network and transport headers // set if the original packet buffer did. func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer { - newPkt := NewPacketBuffer(PacketBufferOptions{ + newPk := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: reservedHeaderBytes, Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(), IsForwardedPacket: true, @@ -355,21 +353,28 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu { consumeBytes := pk.NetworkHeader().View().Size() - if _, consumed := newPkt.NetworkHeader().Consume(consumeBytes); !consumed { + if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed { panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes)) } - newPkt.NetworkProtocolNumber = pk.NetworkProtocolNumber + newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber } { consumeBytes := pk.TransportHeader().View().Size() - if _, consumed := newPkt.TransportHeader().Consume(consumeBytes); !consumed { + if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed { panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes)) } - newPkt.TransportProtocolNumber = pk.TransportProtocolNumber + newPk.TransportProtocolNumber = pk.TransportProtocolNumber } - return newPkt + // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to + // maintain this flag in the packet. Currently conntrack needs this flag to + // tell if a noop connection should be inserted at Input hook. Once conntrack + // redefines the manipulation field as mutable, we won't need the special noop + // connection. + newPk.NatDone = pk.NatDone + + return newPk } // headerInfo stores metadata about a header in a packet. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 181ef799e..7c998eaae 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -34,12 +34,16 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 28b49c6be..bdf4a64b9 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -15,19 +15,24 @@ package iptables_test import ( + "bytes" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) type inputIfNameMatcher struct { @@ -1156,3 +1161,286 @@ func TestInputHookWithLocalForwarding(t *testing.T) { }) } } + +func TestSNAT(t *testing.T) { + const listenPort = 8080 + + type endpointAndAddresses struct { + serverEP tcpip.Endpoint + serverAddr tcpip.Address + serverReadableCH chan struct{} + + clientEP tcpip.Endpoint + clientAddr tcpip.Address + clientReadableCH chan struct{} + + nattedClientAddr tcpip.Address + } + + newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + t.Cleanup(func() { + wq.EventUnregister(&we) + }) + + ep, err := s.NewEndpoint(transProto, netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) + } + t.Cleanup(ep.Close) + + return ep, ch + } + + tests := []struct { + name string + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses + }{ + { + name: "IPv4 host1 server with host2 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + ipt := routerStack.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} + filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err) + } + + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + + nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, + } + }, + }, + { + name: "IPv6 host1 server with host2 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + ipt := routerStack.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} + filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err) + } + + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + + nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, + } + }, + }, + } + + subTests := []struct { + name string + proto tcpip.TransportProtocolNumber + expectedConnectErr tcpip.Error + setupServer func(t *testing.T, ep tcpip.Endpoint) + setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + needRemoteAddr bool + }{ + { + name: "UDP", + proto: udp.ProtocolNumber, + expectedConnectErr: nil, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + if err := ep.Connect(clientAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) + } + return nil, nil + }, + needRemoteAddr: true, + }, + { + name: "TCP", + proto: tcp.ProtocolNumber, + expectedConnectErr: &tcpip.ErrConnectStarted{}, + setupServer: func(t *testing.T, ep tcpip.Endpoint) { + t.Helper() + + if err := ep.Listen(1); err != nil { + t.Fatalf("ep.Listen(1): %s", err) + } + }, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + var addr tcpip.FullAddress + for { + newEP, wq, err := ep.Accept(&addr) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Accept(_): %s", err) + } + if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( + "NIC", + )); diff != "" { + t.Errorf("accepted address mismatch (-want +got):\n%s", diff) + } + + we, newCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + return newEP, newCH + } + }, + needRemoteAddr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } + } + nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr} + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + nattedClientAddr.Port = addr.Port + } + + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch + } + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + var r bytes.Reader + r.Reset(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(&r, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } + } + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + var buf bytes.Buffer + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } + + { + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + read(serverCH, serverEP, data, nattedClientAddr) + } + + { + data := []byte{5, 6, 7, 8, 9, 10, 11, 12} + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 947bcc7b1..c69410859 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -40,6 +40,14 @@ const ( Host2NICID = 4 ) +// Common NIC names used by tests. +const ( + Host1NICName = "host1NIC" + RouterNIC1Name = "routerNIC1" + RouterNIC2Name = "routerNIC2" + Host2NICName = "host2NIC" +) + // Common link addresses used by tests. const ( LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") @@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2) routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4) - if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err) + { + opts := stack.NICOptions{Name: Host1NICName} + if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil { + t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err) + { + opts := stack.NICOptions{Name: RouterNIC1Name} + if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err) + { + opts := stack.NICOptions{Name: RouterNIC2Name} + if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err) + } } - if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) + { + opts := stack.NICOptions{Name: Host2NICName} + if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil { + t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err) + } } if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a3002abf3..407ab2664 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2066,7 +2066,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber) e.UnlockUser() if err != nil { return err -- cgit v1.2.3