diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-01-13 16:02:26 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-13 16:04:33 -0800 |
commit | 25b5ec7135a6de80674ac1ad4d2289c29e156f42 (patch) | |
tree | 5939fd3fd2e4afb2de9c4062cfd9066c544ab46e /pkg/tcpip | |
parent | 1efe0ebc5973ec8a06b881c087dae2183898504b (diff) |
Do not resolve remote link address at transport layer
Link address resolution is performed at the link layer (if required) so
we can defer it from the transport layer. When link resolution is
required, packets will be queued and sent once link resolution
completes. If link resolution fails, the transport layer will receive a
control message indicating that the stack failed to route the packet.
tcpip.Endpoint.Write no longer returns a channel now that writes do not
wait for link resolution at the transport layer.
tcpip.ErrNoLinkAddress is no longer used so it is removed.
Removed calls to stack.Route.ResolveWith from the transport layer so
that link resolution is performed when a route is created in response
to an incoming packet (e.g. to complete TCP handshakes or send a RST).
Tests:
- integration_test.TestForwarding
- integration_test.TestTCPLinkResolutionFailure
Fixes #4458
RELNOTES: n/a
PiperOrigin-RevId: 351684158
Diffstat (limited to 'pkg/tcpip')
33 files changed, 714 insertions, 581 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 7193f56ad..85a0b8b90 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -397,22 +397,9 @@ func (c *TCPConn) Write(b []byte) (int, error) { } var n int64 - var resCh <-chan struct{} - n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) nbytes += int(n) v.TrimFront(int(n)) - - if resCh != nil { - select { - case <-deadline: - return nbytes, c.newOpError("write", &timeoutError{}) - case <-resCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } } if err == nil { @@ -666,17 +653,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { v := buffer.NewView(len(b)) copy(v, b) - n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts) - if resCh != nil { - select { - case <-deadline: - return int(n), c.newRemoteOpError("write", addr, &timeoutError{}) - case <-resCh: - } - - n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) - } - + n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -689,7 +666,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { case <-notifyCh: } - n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) if err != tcpip.ErrWouldBlock { break } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e9ff70d04..cc045c7a9 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -64,6 +64,7 @@ const ( var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() +var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -87,6 +88,21 @@ type endpoint struct { } } +// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. +func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // handleControl expects the entire offending packet to be in the packet + // buffer's data field. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + pkt.NICID = e.nic.ID() + pkt.NetworkProtocolNumber = ProtocolNumber + // Use the same control type as an ICMPv4 destination host unreachable error + // since the host is considered unreachable if we cannot resolve the link + // address to the next hop. + e.handleControl(stack.ControlNoRoute, 0, pkt) +} + // NewEndpoint creates a new ipv4 endpoint. func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bbce1ef78..0ec0a0fef 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -645,29 +645,18 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } - for { - _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}) - if resCh != nil { - if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err) - } - for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, - {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, - } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { - if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { - t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) - } - }) + if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { + t.Fatalf("ep.Write(_): %s", err) + } + for _, args := range []routeArgs{ + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, + {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, + } { + routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { + if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { + t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) } - <-resCh - continue - } - if err != nil { - t.Fatalf("ep.Write(_) = (_, _, %s)", err) - } - break + }) } for _, args := range []routeArgs{ diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index f2018d073..2f82c3d5f 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -163,6 +163,7 @@ func getLabel(addr tcpip.Address) uint8 { panic(fmt.Sprintf("should have a label for address = %s", addr)) } +var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -224,6 +225,18 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. +func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // handleControl expects the entire offending packet to be in the packet + // buffer's data field. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + pkt.NICID = e.nic.ID() + pkt.NetworkProtocolNumber = ProtocolNumber + e.handleControl(stack.ControlAddressUnreachable, 0, pkt) +} + // onAddressAssignedLocked handles an address being assigned. // // Precondition: e.mu must be exclusively locked. diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 4777163cd..a7da9dcd9 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -82,7 +82,7 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) { v.CapLength(n) for len(v) > 0 { - n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) if err != nil { fmt.Println("Write failed:", err) return diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 6883045b5..03b2f2d6f 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -83,7 +83,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe got, ch, err := c.get(addr, linkRes, "", nil, nil) if err == tcpip.ErrWouldBlock { if attemptedResolution { - return got, tcpip.ErrNoLinkAddress + return got, tcpip.ErrTimeout } attemptedResolution = true <-ch @@ -253,8 +253,8 @@ func TestCacheResolutionFailed(t *testing.T) { before := atomic.LoadUint32(&requestCount) e.addr.Addr += "2" - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { + t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -269,8 +269,8 @@ func TestCacheResolutionTimeout(t *testing.T) { linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} e := testAddrs[0] - if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { + t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) } } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 664cc6fa0..5f216ca21 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -268,17 +268,6 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { } } -// SourceLinkAddress returns the source link address of the packet. -func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { - link := pk.LinkHeader().View() - - if link.IsEmpty() { - return "" - } - - return header.Ethernet(link).SourceAddress() -} - // Network returns the network header as a header.Network. // // Network should only be called when NetworkHeader has been set. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 4a3adcf33..bded8814e 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -101,10 +101,12 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } for _, p := range packets { - if cancelled { - p.route.Stats().IP.OutgoingPacketErrors.Increment() - } else if p.route.IsResolutionRequired() { + if cancelled || p.route.IsResolutionRequired() { p.route.Stats().IP.OutgoingPacketErrors.Increment() + + if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { + linkResolvableEP.HandleLinkResolutionFailure(pkt) + } } else { p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 4795208b4..924790779 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -55,7 +55,19 @@ type ControlType int // The following are the allowed values for ControlType values. // TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlNetworkUnreachable ControlType = iota + // ControlAddressUnreachable indicates that an IPv6 packet did not reach its + // destination as the destination address was unreachable. + // + // This maps to the ICMPv6 Destination Ureachable Code 3 error; see + // RFC 4443 section 3.1 for more details. + ControlAddressUnreachable ControlType = iota + ControlNetworkUnreachable + // ControlNoRoute indicates that an IPv4 packet did not reach its destination + // because the destination host was unreachable. + // + // This maps to the ICMPv4 Destination Ureachable Code 1 error; see + // RFC 791's Destination Unreachable Message section (page 4) for more + // details. ControlNoRoute ControlPacketTooBig ControlPortUnreachable @@ -503,6 +515,13 @@ type NetworkInterface interface { WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } +// LinkResolvableNetworkEndpoint handles link resolution events. +type LinkResolvableNetworkEndpoint interface { + // HandleLinkResolutionFailure is called when link resolution prevents the + // argument from having been sent. + HandleLinkResolutionFailure(*PacketBuffer) +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 0ff32c6ea..a2ab7537c 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -90,14 +90,14 @@ func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.Rea return tcpip.ReadResult{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } v, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, @@ -105,10 +105,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions }) _ = pkt.TransportHeader().Push(fakeTransHeaderLen) if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { - return 0, nil, err + return 0, err } - return int64(len(v)), nil, nil + return int64(len(v)), nil } // SetSockOpt sets a socket option. Currently not supported. @@ -222,7 +222,6 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * if err != nil { return } - route.ResolveWith(pkt.SourceLinkAddress()) ep := &fakeTransportEndpoint{ TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -522,8 +521,7 @@ func TestTransportSend(t *testing.T) { // Create buffer that will hold the payload. view := buffer.NewView(30) - _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { + if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("write failed: %v", err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f798056c0..002ddaf67 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -104,7 +104,6 @@ var ( ErrConnectionAborted = &Error{msg: "connection aborted"} ErrNoSuchFile = &Error{msg: "no such file"} ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} - ErrNoLinkAddress = &Error{msg: "no remote link address"} ErrBadAddress = &Error{msg: "bad address"} ErrNetworkUnreachable = &Error{msg: "network is unreachable"} ErrMessageTooLong = &Error{msg: "message too long"} @@ -154,7 +153,6 @@ func StringToError(s string) *Error { ErrConnectionAborted, ErrNoSuchFile, ErrInvalidOptionValue, - ErrNoLinkAddress, ErrBadAddress, ErrNetworkUnreachable, ErrMessageTooLong, @@ -640,12 +638,7 @@ type Endpoint interface { // stream (TCP) Endpoints may return partial writes, and even then only // in the case where writing additional data would block. Other Endpoints // will either write the entire message or return an error. - // - // For UDP and Ping sockets if address resolution is required, - // ErrNoLinkAddress and a notification channel is returned for the caller to - // block. Channel is closed once address resolution is complete (success or - // not). The channel is only non-nil in this case. - Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -1784,9 +1777,6 @@ type SendErrors struct { // NoRoute is the number of times we failed to resolve IP route. NoRoute StatCounter - - // NoLinkAddr is the number of times we failed to resolve ARP. - NoLinkAddr StatCounter } // ReadErrors collects segment read errors from an endpoint read call. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index ca1e88e99..1742a178d 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -31,5 +31,6 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 4c2084d19..49acd504e 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -157,13 +158,13 @@ func TestForwarding(t *testing.T) { tests := []struct { name string - epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses }{ { name: "IPv4 host1 server with host2 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) return endpointAndAddresses{ serverEP: ep1, serverAddr: host1IPv4Addr.AddressWithPrefix.Address, @@ -177,9 +178,9 @@ func TestForwarding(t *testing.T) { }, { name: "IPv6 host2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) return endpointAndAddresses{ serverEP: ep1, serverAddr: host2IPv6Addr.AddressWithPrefix.Address, @@ -193,9 +194,9 @@ func TestForwarding(t *testing.T) { }, { name: "IPv4 host2 server with routerNIC1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber) + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber) return endpointAndAddresses{ serverEP: ep1, serverAddr: host2IPv4Addr.AddressWithPrefix.Address, @@ -209,9 +210,9 @@ func TestForwarding(t *testing.T) { }, { name: "IPv6 routerNIC2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { - ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) return endpointAndAddresses{ serverEP: ep1, serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address, @@ -225,202 +226,270 @@ func TestForwarding(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - } - - host1Stack := stack.New(stackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - - host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) - routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) - } - if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - { - Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - }) - - epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack) - defer epsAndAddrs.serverEP.Close() - defer epsAndAddrs.clientEP.Close() - - serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} - if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { - t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) - } - clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} - if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { - t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) - } - - write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) { + subTests := []struct { + name string + proto tcpip.TransportProtocolNumber + expectedConnectErr *tcpip.Error + setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + needRemoteAddr bool + }{ + { + name: "UDP", + proto: udp.ProtocolNumber, + expectedConnectErr: nil, + setupServerSide: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() - dataPayload := tcpip.SlicePayload(data) - wOpts := tcpip.WriteOptions{To: to} - n, ch, err := ep.Write(dataPayload, wOpts) - if err == tcpip.ErrNoLinkAddress { - // Wait for link resolution to complete. - <-ch - n, _, err = ep.Write(dataPayload, wOpts) - } - if err != nil { - t.Fatalf("ep.Write(_, _): %s", err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want) + if err := ep.Connect(clientAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) } - } - - data := []byte{1, 2, 3, 4} - write(epsAndAddrs.clientEP, data, &serverAddr) - - read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress { + return nil, nil + }, + needRemoteAddr: true, + }, + { + name: "TCP", + proto: tcp.ProtocolNumber, + expectedConnectErr: tcpip.ErrConnectStarted, + setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() - // Wait for the endpoint to be readable. - <-ch - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, len(data), opts) - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + if err := ep.Listen(1); err != nil { + t.Fatalf("ep.Listen(1): %s", err) } - - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(data), - Total: len(data), - RemoteAddr: tcpip.FullAddress{Addr: expectedFrom}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes(), data); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - - if t.Failed() { - t.FailNow() + var addr tcpip.FullAddress + for { + newEP, wq, err := ep.Accept(&addr) + if err == tcpip.ErrWouldBlock { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Accept(_): %s", err) + } + if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( + "NIC", + )); diff != "" { + t.Errorf("accepted address mismatch (-want +got):\n%s", diff) + } + + we, newCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + return newEP, newCH } + }, + needRemoteAddr: false, + }, + } - return res.RemoteAddr - } - - addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr) - // Unspecify the NIC since NIC IDs are meaningless across stacks. - addr.NIC = 0 - - data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) - write(epsAndAddrs.serverEP, data, &addr) - addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr) - if addr.Port != listenPort { - t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + { + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + { + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) + defer epsAndAddrs.serverEP.Close() + defer epsAndAddrs.clientEP.Close() + + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr { + t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr) + } + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + clientAddr = addr + clientAddr.NIC = 0 + } + + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerSide(t, serverEP, serverCH, clientAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch + } + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + dataPayload := tcpip.SlicePayload(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(dataPayload, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } + } + + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + // Wait for the endpoint to be readable. + <-ch + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err := ep.Read(&buf, len(data), opts) + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } + + read(serverCH, serverEP, data, clientAddr) + + data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) + }) } }) } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index b4bffaec1..ed00c90d4 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -29,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,6 +56,13 @@ var ( PrefixLen: 8, }, } + ipv4Addr3 = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.3").To4()), + PrefixLen: 8, + }, + } ipv6Addr1 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -68,8 +77,65 @@ var ( PrefixLen: 64, }, } + ipv6Addr3 = tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::3").To16()), + PrefixLen: 64, + }, + } ) +func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) { + host1Stack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + + host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Addr1.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: ipv6Addr1.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Addr2.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: ipv6Addr2.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + }) + + return host1Stack, host2Stack +} + // TestPing tests that two hosts can ping eachother when link resolution is // enabled. func TestPing(t *testing.T) { @@ -128,51 +194,7 @@ func TestPing(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, } - host1Stack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - - host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: ipv4Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: ipv6Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: ipv4Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: ipv6Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - }) + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) var wq waiter.Queue we, waiterCH := waiter.NewChannelEntry(nil) @@ -183,19 +205,12 @@ func TestPing(t *testing.T) { } defer ep.Close() - // The first write should trigger link resolution. icmpBuf := test.icmpBuf(t) wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress) - } else { - // Wait for link resolution to complete. - <-ch - } - if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { t.Fatalf("ep.Write(_, _): %s", err) } else if want := int64(len(icmpBuf)); n != want { - t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want) + t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) } // Wait for the endpoint to be readable. @@ -224,3 +239,159 @@ func TestPing(t *testing.T) { }) } } + +func TestTCPLinkResolutionFailure(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedWriteErr *tcpip.Error + sockError tcpip.SockError + }{ + { + name: "IPv4 with resolvable remote", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv6 with resolvable remote", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv4 without resolvable remote", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + expectedWriteErr: tcpip.ErrNoRoute, + sockError: tcpip.SockError{ + Err: tcpip.ErrNoRoute, + ErrType: byte(header.ICMPv4DstUnreachable), + ErrCode: byte(header.ICMPv4HostUnreachable), + ErrOrigin: tcpip.SockExtErrorOriginICMP, + Dst: tcpip.FullAddress{ + NIC: host1NICID, + Addr: ipv4Addr3.AddressWithPrefix.Address, + Port: 1234, + }, + Offender: tcpip.FullAddress{ + NIC: host1NICID, + Addr: ipv4Addr1.AddressWithPrefix.Address, + }, + NetProto: ipv4.ProtocolNumber, + }, + }, + { + name: "IPv6 without resolvable remote", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + expectedWriteErr: tcpip.ErrNoRoute, + sockError: tcpip.SockError{ + Err: tcpip.ErrNoRoute, + ErrType: byte(header.ICMPv6DstUnreachable), + ErrCode: byte(header.ICMPv6AddressUnreachable), + ErrOrigin: tcpip.SockExtErrorOriginICMP6, + Dst: tcpip.FullAddress{ + NIC: host1NICID, + Addr: ipv6Addr3.AddressWithPrefix.Address, + Port: 1234, + }, + Offender: tcpip.FullAddress{ + NIC: host1NICID, + Addr: ipv6Addr1.AddressWithPrefix.Address, + }, + NetProto: ipv6.ProtocolNumber, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + } + + host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) + + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) + } + defer listenerEP.Close() + + listenerAddr := tcpip.FullAddress{Port: 1234} + if err := listenerEP.Bind(listenerAddr); err != nil { + t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) + } + + if err := listenerEP.Listen(1); err != nil { + t.Fatalf("listenerEP.Listen(1): %s", err) + } + + var clientWQ waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&we, waiter.EventOut|waiter.EventErr) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) + } + defer clientEP.Close() + + sockOpts := clientEP.SocketOptions() + sockOpts.SetRecvError(true) + + remoteAddr := listenerAddr + remoteAddr.Addr = test.remoteAddr + if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted) + } + + // Wait for an error due to link resolution failing, or the endpoint to be + // writable. + <-ch + var wOpts tcpip.WriteOptions + if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr { + t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + } + + if test.expectedWriteErr == nil { + return + } + + sockErr := sockOpts.DequeueErr() + if sockErr == nil { + t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil") + } + + sockErrCmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported(tcpip.SockError{}), + cmp.Comparer(func(a, b *tcpip.Error) bool { + // tcpip.Error holds an unexported field but the errors netstack uses + // are pre defined so we can simply compare pointers. + return a == b + }), + // Ignore the payload since we do not know the TCP seq/ack numbers. + checker.IgnoreCmpPath( + "Payload", + ), + } + + if addr, err := clientEP.GetLocalAddress(); err != nil { + t.Fatalf("clientEP.GetLocalAddress(): %s", err) + } else { + test.sockError.Offender.Port = addr.Port + } + if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { + t.Errorf("socket error mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index cb6169cfc..a59f25cc3 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -232,12 +232,12 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { Port: localPort, }, } - n, _, err := sep.Write(tcpip.SlicePayload(data), wopts) + n, err := sep.Write(tcpip.SlicePayload(data), wopts) if err != nil { t.Fatalf("sep.Write(_, _): %s", err) } if want := int64(len(data)); n != want { - t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) + t.Fatalf("got sep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) } var buf bytes.Buffer diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index b42375695..eabc87938 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -587,10 +587,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, } data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) - if n, _, err := wep.ep.Write(data, writeOpts); err != nil { + if n, err := wep.ep.Write(data, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { - t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) + t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) } for j, rep := range eps { diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 52cf89b54..76f7f54c6 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -197,10 +197,10 @@ func TestLocalPing(t *testing.T) { payload := tcpip.SlicePayload(test.icmpBuf(t)) var wOpts tcpip.WriteOptions - if n, _, err := ep.Write(payload, wOpts); err != nil { + if n, err := ep.Write(payload, wOpts); err != nil { t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) } // Wait for the endpoint to become readable. @@ -335,14 +335,14 @@ func TestLocalUDP(t *testing.T) { wOpts := tcpip.WriteOptions{ To: &serverAddr, } - if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { - t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) + if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) } else if subTest.expectedWriteErr != nil { // Nothing else to test if we expected not to be able to send the // UDP packet. return } else if n != int64(len(clientPayload)) { - t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload)) + t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload)) } } @@ -382,10 +382,10 @@ func TestLocalUDP(t *testing.T) { wOpts := tcpip.WriteOptions{ To: &clientAddr, } - if n, _, err := server.Write(serverPayload, wOpts); err != nil { + if n, err := server.Write(serverPayload, wOpts); err != nil { t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) } else if n != int64(len(serverPayload)) { - t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload)) + t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index c32fe5c4f..87277fbd3 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -236,8 +236,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - n, ch, err := e.write(p, opts) +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { + n, err := e.write(p, opts) switch err { case nil: e.stats.PacketsSent.Increment() @@ -247,8 +247,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.stats.WriteErrors.WriteClosed.Increment() case tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoLinkAddress: - e.stats.SendErrors.NoLinkAddr.Increment() case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() @@ -256,13 +254,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // For all other errors when writing to the network layer. e.stats.SendErrors.SendToNetworkFailed.Increment() } - return n, ch, err + return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } to := opts.To @@ -272,14 +270,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } // Prepare for write. for { retry, err := e.prepareForWrite(to) if err != nil { - return 0, nil, err + return 0, err } if !retry { @@ -294,7 +292,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID := to.NIC if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } nicID = e.BindNICID @@ -302,31 +300,22 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { - return 0, nil, err + return 0, err } // Find the endpoint. r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) if err != nil { - return 0, nil, err + return 0, err } defer r.Release() route = r } - if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - return 0, ch, tcpip.ErrNoLinkAddress - } - return 0, nil, err - } - } - v, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } switch e.NetProto { @@ -338,10 +327,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if err != nil { - return 0, nil, err + return 0, err } - return int64(len(v)), nil, nil + return int64(len(v)), nil } // SetSockOpt sets a socket option. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 3ab060751..c3b3b8d34 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -207,9 +207,9 @@ func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpi return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index dd260535f..425bcf3ee 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -234,20 +234,20 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip } // Write implements tcpip.Endpoint.Write. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // We can create, but not write to, unassociated IPv6 endpoints. if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } } - n, ch, err := e.write(p, opts) + n, err := e.write(p, opts) switch err { case nil: e.stats.PacketsSent.Increment() @@ -257,8 +257,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.stats.WriteErrors.WriteClosed.Increment() case tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoLinkAddress: - e.stats.SendErrors.NoLinkAddr.Increment() case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() @@ -266,25 +264,25 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // For all other errors when writing to the network layer. e.stats.SendErrors.SendToNetworkFailed.Increment() } - return n, ch, err + return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } e.mu.RLock() defer e.mu.RUnlock() if e.closed { - return 0, nil, tcpip.ErrInvalidEndpointState + return 0, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } // If this is an unassociated socket and callee provided a nonzero @@ -292,7 +290,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() // Update dstAddr with the address in the IP header, unless @@ -313,7 +311,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - return 0, nil, tcpip.ErrDestinationRequired + return 0, tcpip.ErrDestinationRequired } return e.finishWrite(payloadBytes, e.route) @@ -323,42 +321,30 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - return 0, nil, err + return 0, err } - n, ch, err := e.finishWrite(payloadBytes, route) + n, err := e.finishWrite(payloadBytes, route) route.Release() - return n, ch, err + return n, err } // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { - // We may need to resolve the route (match a link layer address to the - // network address). If that requires blocking (e.g. to use ARP), - // return a channel on which the caller can wait. - if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - return 0, ch, tcpip.ErrNoLinkAddress - } - return 0, nil, err - } - } - +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) { if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), }) if err := route.WriteHeaderIncludedPacket(pkt); err != nil { - return 0, nil, err + return 0, err } } else { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -371,11 +357,11 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, }, pkt); err != nil { - return 0, nil, err + return 0, err } } - return int64(len(payloadBytes)), nil, nil + return int64(len(payloadBytes)), nil } // Disconnect implements tcpip.Endpoint.Disconnect. diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 2d96a65bd..9e8872fc9 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -210,7 +210,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i if err != nil { return nil, err } - route.ResolveWith(s.remoteLinkAddr) n := newEndpoint(l.stack, netProto, queue) n.ops.SetV6Only(l.v6Only) @@ -573,7 +572,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er return err } defer route.Release() - route.ResolveWith(s.remoteLinkAddr) // Send SYN without window scaling because we currently // don't encode this information in the cookie. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a00ef97c6..f45d26a87 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -460,66 +460,9 @@ func (h *handshake) processSegments() *tcpip.Error { return nil } -func (h *handshake) resolveRoute() *tcpip.Error { - // Set up the wakers. - var s sleep.Sleeper - resolutionWaker := &sleep.Waker{} - s.AddWaker(resolutionWaker, wakerForResolution) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - defer s.Done() - - // Initial action is to resolve route. - index := wakerForResolution - attemptedResolution := false - for { - switch index { - case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock { - if err != nil { - h.ep.stats.SendErrors.NoRoute.Increment() - } - // Either success (err == nil) or failure. - return err - } - if attemptedResolution { - h.ep.stats.SendErrors.NoLinkAddr.Increment() - return tcpip.ErrNoLinkAddress - } - attemptedResolution = true - // Resolution not completed. Keep trying... - - case wakerForNotification: - n := h.ep.fetchNotifications() - if n¬ifyClose != 0 { - return tcpip.ErrAborted - } - if n¬ifyDrain != 0 { - close(h.ep.drainDone) - h.ep.mu.Unlock() - <-h.ep.undrain - h.ep.mu.Lock() - } - if n¬ifyError != 0 { - return h.ep.lastErrorLocked() - } - } - - // Wait for notification. - h.ep.mu.Unlock() - index, _ = s.Fetch(true /* block */) - h.ep.mu.Lock() - } -} - // start resolves the route if necessary and sends the first // SYN/SYN-ACK. func (h *handshake) start() *tcpip.Error { - if h.ep.route.IsResolutionRequired() { - if err := h.resolveRoute(); err != nil { - return err - } - } - h.startTime = time.Now() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 25b180fa5..ddbed7e46 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1507,7 +1507,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -1520,7 +1520,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.sndBufMu.Unlock() e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err + return 0, err } // We can release locks while copying data. @@ -1541,7 +1541,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.sndBufMu.Unlock() e.UnlockUser() } - return 0, nil, perr + return 0, perr } if !opts.Atomic { @@ -1555,7 +1555,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.sndBufMu.Unlock() e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err + return 0, err } // Discard any excess data copied in due to avail being reduced due @@ -1575,7 +1575,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Do the work inline. e.handleWrite() e.UnlockUser() - return int64(len(v)), nil, nil + return int64(len(v)), nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -2779,6 +2779,9 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt case stack.ControlNoRoute: e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) + case stack.ControlAddressUnreachable: + e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) + case stack.ControlNetworkUnreachable: e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c9e194f82..1720370c9 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -222,7 +222,6 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error return err } defer route.Release() - route.ResolveWith(s.remoteLinkAddr) // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index c5a6d2fba..7cca4def5 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -49,11 +49,10 @@ type segment struct { // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of // individual members for link/network packet info. - srcAddr tcpip.Address - dstAddr tcpip.Address - netProto tcpip.NetworkProtocolNumber - nicID tcpip.NICID - remoteLinkAddr tcpip.LinkAddress + srcAddr tcpip.Address + dstAddr tcpip.Address + netProto tcpip.NetworkProtocolNumber + nicID tcpip.NICID data buffer.VectorisedView `state:".(buffer.VectorisedView)"` @@ -89,13 +88,12 @@ type segment struct { func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { netHdr := pkt.Network() s := &segment{ - refCnt: 1, - id: id, - srcAddr: netHdr.SourceAddress(), - dstAddr: netHdr.DestinationAddress(), - netProto: pkt.NetworkProtocolNumber, - nicID: pkt.NICID, - remoteLinkAddr: pkt.SourceLinkAddress(), + refCnt: 1, + id: id, + srcAddr: netHdr.SourceAddress(), + dstAddr: netHdr.DestinationAddress(), + netProto: pkt.NetworkProtocolNumber, + nicID: pkt.NICID, } s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) @@ -128,7 +126,6 @@ func (s *segment) clone() *segment { window: s.window, netProto: s.netProto, nicID: s.nicID, - remoteLinkAddr: s.remoteLinkAddr, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index b9993ce1a..f7aaee23f 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -49,7 +49,7 @@ func TestFastRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -214,7 +214,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -256,7 +256,7 @@ func TestCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -361,7 +361,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -470,11 +470,11 @@ func TestRetransmit(t *testing.T) { // Write all the data in two shots. Packets will only be written at the // MTU size though. half := data[:len(data)/2] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } half = data[len(data)/2:] - if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 9818ffa0f..342eb5eb8 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -68,7 +68,7 @@ func TestRACKUpdate(t *testing.T) { // Write the data. xmitTime = time.Now() - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -120,7 +120,7 @@ func TestRACKDetectReorder(t *testing.T) { } // Write the data. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -151,7 +151,7 @@ func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.Vie } // Write the data. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index faf0c0ad7..6635bb815 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -402,7 +402,7 @@ func TestSACKRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index aeceee7e0..729bf7ef5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1348,7 +1348,7 @@ func TestTOSV4(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1397,7 +1397,7 @@ func TestTrafficClassV6(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1977,7 +1977,11 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Keep the payload size < segment overhead and such that it is a multiple // of the window scaled value. This enables the test to perform equality // checks on the incoming receive window. - payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale)) + payloadSize := 1 << c.RcvdWindowScale + if payloadSize >= tcp.SegSize { + t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize) + } + payload := generateRandomPayload(t, payloadSize) payloadLen := seqnum.Size(len(payload)) iss := seqnum.Value(789) seqNum := iss.Add(1) @@ -2173,7 +2177,7 @@ func TestSimpleSend(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2214,8 +2218,7 @@ func TestZeroWindowSend(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2283,7 +2286,7 @@ func TestScaledWindowConnect(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2315,7 +2318,7 @@ func TestNonScaledWindowConnect(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2389,7 +2392,7 @@ func TestScaledWindowAccept(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2463,7 +2466,7 @@ func TestNonScaledWindowAccept(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2626,7 +2629,7 @@ func TestSegmentMerging(t *testing.T) { // anymore packets from going out. for i := 0; i < tcp.InitialCwnd; i++ { view := buffer.NewViewFromBytes([]byte{0}) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2637,7 +2640,7 @@ func TestSegmentMerging(t *testing.T) { for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2707,7 +2710,7 @@ func TestDelay(t *testing.T) { for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2754,7 +2757,7 @@ func TestUndelay(t *testing.T) { allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2838,7 +2841,7 @@ func TestMSSNotDelayed(t *testing.T) { allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} for i, data := range allData { view := buffer.NewViewFromBytes(data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2889,7 +2892,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3321,7 +3324,7 @@ func TestSendOnResetConnection(t *testing.T) { // Try to write. view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) } } @@ -3344,7 +3347,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventHUp) defer c.WQ.EventUnregister(&waitEntry) - _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3401,7 +3404,7 @@ func TestMaxRTO(t *testing.T) { c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3450,7 +3453,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) } - if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } pkt := c.GetPacket() @@ -3588,7 +3591,7 @@ func TestFinWithNoPendingData(t *testing.T) { // Write something out, and have it acknowledged. view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3661,7 +3664,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // any of them. view := buffer.NewView(10) for i := tcp.InitialCwnd; i > 0; i-- { - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } } @@ -3747,7 +3750,7 @@ func TestFinWithPendingData(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3773,7 +3776,7 @@ func TestFinWithPendingData(t *testing.T) { }) // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3834,7 +3837,7 @@ func TestFinWithPartialAck(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. view := buffer.NewView(10) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3871,7 +3874,7 @@ func TestFinWithPartialAck(t *testing.T) { ) // Write new data, but don't acknowledge it. - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3978,7 +3981,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { // Send some data. Check that it's capped by the window size. view := buffer.NewView(65535) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4035,9 +4038,6 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) } - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0) - } } func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { @@ -4607,7 +4607,7 @@ func TestSelfConnect(t *testing.T) { data := []byte{1, 2, 3} view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4785,7 +4785,7 @@ func TestPathMTUDiscovery(t *testing.T) { data[i] = byte(i) } - if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5074,7 +5074,7 @@ func TestKeepalive(t *testing.T) { // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. view := buffer.NewView(3) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5903,9 +5903,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) - - if err != nil { + if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7103,7 +7101,7 @@ func TestTCPCloseWithData(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7202,7 +7200,7 @@ func TestTCPUserTimeout(t *testing.T) { // Send some data and wait before ACKing it. view := buffer.NewView(3) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 9e02d467d..88fb054bb 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -154,7 +154,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } @@ -217,7 +217,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd view := buffer.NewView(len(data)) copy(view, data) - if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5d87f3a7e..520a0ac9d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -417,8 +417,8 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - n, ch, err := e.write(p, opts) +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { + n, err := e.write(p, opts) switch err { case nil: e.stats.PacketsSent.Increment() @@ -428,8 +428,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c e.stats.WriteErrors.WriteClosed.Increment() case tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoLinkAddress: - e.stats.SendErrors.NoLinkAddr.Increment() case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() @@ -437,17 +435,17 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // For all other errors when writing to the network layer. e.stats.SendErrors.SendToNetworkFailed.Increment() } - return n, ch, err + return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { if err := e.LastError(); err != nil { - return 0, nil, err + return 0, err } // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, nil, tcpip.ErrInvalidOptionValue + return 0, tcpip.ErrInvalidOptionValue } to := opts.To @@ -463,14 +461,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } // Prepare for write. for { retry, err := e.prepareForWrite(to) if err != nil { - return 0, nil, err + return 0, err } if !retry { @@ -486,7 +484,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID := to.NIC if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, nil, tcpip.ErrNoRoute + return 0, tcpip.ErrNoRoute } nicID = e.BindNICID @@ -494,17 +492,17 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if to.Port == 0 { // Port 0 is an invalid port to send to. - return 0, nil, tcpip.ErrInvalidEndpointState + return 0, tcpip.ErrInvalidEndpointState } dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { - return 0, nil, err + return 0, err } r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { - return 0, nil, err + return 0, err } defer r.Release() @@ -513,21 +511,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, nil, tcpip.ErrBroadcastDisabled - } - - if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - return 0, ch, tcpip.ErrNoLinkAddress - } - return 0, nil, err - } + return 0, tcpip.ErrBroadcastDisabled } v, err := p.FullPayload() if err != nil { - return 0, nil, err + return 0, err } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. @@ -545,7 +534,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c v, ) } - return 0, nil, tcpip.ErrMessageTooLong + return 0, tcpip.ErrMessageTooLong } ttl := e.ttl @@ -575,9 +564,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read // locking is prohibited. if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { - return 0, nil, err + return 0, err } - return int64(len(v)), nil, nil + return int64(len(v)), nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index d7fc21f11..49e673d58 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -75,7 +75,6 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, if err != nil { return nil, err } - route.ResolveWith(r.pkt.SourceLinkAddress()) ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c8da173f1..52403ed78 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -967,7 +967,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) payload := buffer.View(newPayload()) - _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) c.checkEndpointWriteStats(1, epstats, gotErr) @@ -1008,7 +1008,7 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View } } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("Write failed: %s", err) } @@ -1184,7 +1184,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) } @@ -2317,8 +2317,6 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE want.WriteErrors.WriteClosed.IncrementBy(incr) case tcpip.ErrInvalidEndpointState: want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoLinkAddress: - want.SendErrors.NoLinkAddr.IncrementBy(incr) case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: want.SendErrors.NoRoute.IncrementBy(incr) default: @@ -2510,20 +2508,20 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { expectedErrWithoutBcastOpt = nil } - if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) } ep.SocketOptions().SetBroadcast(true) - if n, _, err := ep.Write(data, opts); err != nil { - t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) + if n, err := ep.Write(data, opts); err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) } ep.SocketOptions().SetBroadcast(false) - if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) } }) } |