diff options
-rw-r--r-- | pkg/tcpip/link/pipe/pipe.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 37 | ||||
-rw-r--r-- | pkg/tcpip/stack/pending_packets.go | 51 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 125 |
4 files changed, 222 insertions, 25 deletions
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index d6e83a414..36aa9055c 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -45,12 +45,7 @@ type Endpoint struct { linkAddr tcpip.LinkAddress } -// WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - if !e.linked.IsAttached() { - return nil - } - +func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) { // Note that the local address from the perspective of this endpoint is the // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this @@ -70,16 +65,33 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw // // TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send // and receive queues. - go e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) + go func() { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + } + }() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.linked.IsAttached() { + var pkts stack.PacketBufferList + pkts.PushBack(pkt) + e.deliverPackets(r, proto, pkts) + } return nil } // WritePackets implements stack.LinkEndpoint. -func (*Endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - panic("not implemented") +func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if e.linked.IsAttached() { + e.deliverPackets(r, proto, pkts) + } + + return pkts.Len(), nil } // Attach implements stack.LinkEndpoint. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index d19150a20..447b5c99d 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -358,16 +358,43 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN // WritePackets implements NetworkLinkEndpoint. func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution - // is being peformed like WritePacket. - routeInfo := r.Fields() + // As per relevant RFCs, we should queue packets while we wait for link + // resolution to complete. + // + // RFC 1122 section 2.3.2.2 (for IPv4): + // The link layer SHOULD save (rather than discard) at least + // one (the latest) packet of each set of packets destined to + // the same unresolved IP address, and transmit the saved + // packet when the address has been resolved. + // + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and MAY + // contain more. However, the number of queued packets per neighbor SHOULD + // be limited to some small value. When a queue overflows, the new arrival + // SHOULD replace the oldest entry. Once address resolution completes, the + // node transmits any queued packets. + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + r.Acquire() + n.linkResQueue.enqueue(ch, r, protocol, &pkts) + return pkts.Len(), nil + } + return 0, err + } + + return n.writePackets(r.Fields(), gso, protocol, pkts) +} + +func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - pkt.EgressRoute = routeInfo + pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol } - writtenPackets, err := n.LinkEndpoint.WritePackets(routeInfo, gso, pkts, protocol) + writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 81d8ff6e8..3ac039c7d 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -28,10 +28,26 @@ const ( maxPendingPacketsPerResolution = 256 ) +// pendingPacketBuffer is a pending packet buffer. +// +// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use +// WritePackets so we can use a PacketBufferList everywhere. +type pendingPacketBuffer interface { + len() int +} + +func (*PacketBuffer) len() int { + return 1 +} + +func (p *PacketBufferList) len() int { + return p.Len() +} + type pendingPacket struct { route *Route proto tcpip.NetworkProtocolNumber - pkt *PacketBuffer + pkt pendingPacketBuffer } // packetsPendingLinkResolution is a queue of packets pending link resolution. @@ -54,16 +70,17 @@ func (f *packetsPendingLinkResolution) init() { f.packets = make(map[<-chan struct{}][]pendingPacket) } -func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber) { - r.Stats().IP.OutgoingPacketErrors.Increment() +func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { + n := uint64(pkt.len()) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(n) // ok may be false if the endpoint's stats do not collect IP-related data. if ipEndpointStats, ok := r.outgoingNIC.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { - ipEndpointStats.IPStats().OutgoingPacketErrors.Increment() + ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n) } } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { f.Lock() defer f.Unlock() @@ -73,7 +90,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro packets[0] = pendingPacket{} packets = packets[1:] - incrementOutgoingPacketErrors(r, proto) + incrementOutgoingPacketErrors(r, proto, p.pkt) p.route.Release() } @@ -113,13 +130,29 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro for _, p := range packets { if cancelled || p.route.IsResolutionRequired() { - incrementOutgoingPacketErrors(r, proto) + incrementOutgoingPacketErrors(r, proto, p.pkt) if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { - linkResolvableEP.HandleLinkResolutionFailure(pkt) + switch pkt := p.pkt.(type) { + case *PacketBuffer: + linkResolvableEP.HandleLinkResolutionFailure(pkt) + case *PacketBufferList: + for pb := pkt.Front(); pb != nil; pb = pb.Next() { + linkResolvableEP.HandleLinkResolutionFailure(pb) + } + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) + } } } else { - p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt) + switch pkt := p.pkt.(type) { + case *PacketBuffer: + p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, pkt) + case *PacketBufferList: + p.route.outgoingNIC.writePackets(p.route.Fields(), nil /* gso */, p.proto, *pkt) + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) + } } p.route.Release() } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index af32d3009..7afc4702c 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" @@ -32,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -456,3 +458,126 @@ func TestGetLinkAddress(t *testing.T) { }) } } + +func TestWritePacketsLinkResolution(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedWriteErr *tcpip.Error + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + } + + host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err) + } + defer serverEP.Close() + + serverAddr := tcpip.FullAddress{Port: 1234} + if err := serverEP.Bind(serverAddr); err != nil { + t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err) + } + + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + data := []byte{1, 2} + var pkts stack.PacketBufferList + for _, d := range data { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: buffer.View([]byte{d}).ToVectorisedView(), + }) + pkt.TransportProtocolNumber = udp.ProtocolNumber + length := uint16(pkt.Size()) + udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: serverAddr.Port, + Length: length, + }) + xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + + pkts.PushBack(pkt) + } + + params := stack.NetworkHeaderParams{ + Protocol: udp.ProtocolNumber, + TTL: 64, + TOS: stack.DefaultTOS, + } + + if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil { + t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err) + } else if want := pkts.Len(); want != n { + t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want) + } + + var writer bytes.Buffer + count := 0 + for { + var rOpts tcpip.ReadOptions + res, err := serverEP.Read(&writer, rOpts) + if err != nil { + if err == tcpip.ErrWouldBlock { + // Should not have anymore bytes to read after we read the sent + // number of bytes. + if count == len(data) { + break + } + + <-serverCH + continue + } + + t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err) + } + count += res.Count + } + + if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want { + t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if diff := cmp.Diff(data, writer.Bytes()); diff != "" { + t.Errorf("read bytes mismatch (-want +got):\n%s", diff) + } + }) + } +} |