diff options
author | Kevin Krakauer <krakauer@google.com> | 2019-11-06 14:24:38 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-11-06 14:25:59 -0800 |
commit | e1b21f3c8ca989dc94b25526fda1bb107691f1af (patch) | |
tree | 512a212950a3c3b1579582cfb6ce35befd751bc1 /pkg | |
parent | d0d89ceeddd21f1f22e818d78dc3b07d3669dbb5 (diff) |
Use PacketBuffers, rather than VectorisedViews, in netstack.
PacketBuffers are analogous to Linux's sk_buff. They hold all information about
a packet, headers, and payload. This is important for:
* iptables to access various headers of packets
* Preventing the clutter of passing different net and link headers along with
VectorisedViews to packet handling functions.
This change only affects the incoming packet path, and a future change will
change the outgoing path.
Benchmark Regular PacketBufferPtr PacketBufferConcrete
--------------------------------------------------------------------------------
BM_Recvmsg 400.715MB/s 373.676MB/s 396.276MB/s
BM_Sendmsg 361.832MB/s 333.003MB/s 335.571MB/s
BM_Recvfrom 453.336MB/s 393.321MB/s 381.650MB/s
BM_Sendto 378.052MB/s 372.134MB/s 341.342MB/s
BM_SendmsgTCP/0/1k 353.711MB/s 316.216MB/s 322.747MB/s
BM_SendmsgTCP/0/2k 600.681MB/s 588.776MB/s 565.050MB/s
BM_SendmsgTCP/0/4k 995.301MB/s 888.808MB/s 941.888MB/s
BM_SendmsgTCP/0/8k 1.517GB/s 1.274GB/s 1.345GB/s
BM_SendmsgTCP/0/16k 1.872GB/s 1.586GB/s 1.698GB/s
BM_SendmsgTCP/0/32k 1.017GB/s 1.020GB/s 1.133GB/s
BM_SendmsgTCP/0/64k 475.626MB/s 584.587MB/s 627.027MB/s
BM_SendmsgTCP/0/128k 416.371MB/s 503.434MB/s 409.850MB/s
BM_SendmsgTCP/0/256k 323.449MB/s 449.599MB/s 388.852MB/s
BM_SendmsgTCP/0/512k 243.992MB/s 267.676MB/s 314.474MB/s
BM_SendmsgTCP/0/1M 95.138MB/s 95.874MB/s 95.417MB/s
BM_SendmsgTCP/0/2M 96.261MB/s 94.977MB/s 96.005MB/s
BM_SendmsgTCP/0/4M 96.512MB/s 95.978MB/s 95.370MB/s
BM_SendmsgTCP/0/8M 95.603MB/s 95.541MB/s 94.935MB/s
BM_SendmsgTCP/0/16M 94.598MB/s 94.696MB/s 94.521MB/s
BM_SendmsgTCP/0/32M 94.006MB/s 94.671MB/s 94.768MB/s
BM_SendmsgTCP/0/64M 94.133MB/s 94.333MB/s 94.746MB/s
BM_SendmsgTCP/0/128M 93.615MB/s 93.497MB/s 93.573MB/s
BM_SendmsgTCP/0/256M 93.241MB/s 95.100MB/s 93.272MB/s
BM_SendmsgTCP/1/1k 303.644MB/s 316.074MB/s 308.430MB/s
BM_SendmsgTCP/1/2k 537.093MB/s 584.962MB/s 529.020MB/s
BM_SendmsgTCP/1/4k 882.362MB/s 939.087MB/s 892.285MB/s
BM_SendmsgTCP/1/8k 1.272GB/s 1.394GB/s 1.296GB/s
BM_SendmsgTCP/1/16k 1.802GB/s 2.019GB/s 1.830GB/s
BM_SendmsgTCP/1/32k 2.084GB/s 2.173GB/s 2.156GB/s
BM_SendmsgTCP/1/64k 2.515GB/s 2.463GB/s 2.473GB/s
BM_SendmsgTCP/1/128k 2.811GB/s 3.004GB/s 2.946GB/s
BM_SendmsgTCP/1/256k 3.008GB/s 3.159GB/s 3.171GB/s
BM_SendmsgTCP/1/512k 2.980GB/s 3.150GB/s 3.126GB/s
BM_SendmsgTCP/1/1M 2.165GB/s 2.233GB/s 2.163GB/s
BM_SendmsgTCP/1/2M 2.370GB/s 2.219GB/s 2.453GB/s
BM_SendmsgTCP/1/4M 2.005GB/s 2.091GB/s 2.214GB/s
BM_SendmsgTCP/1/8M 2.111GB/s 2.013GB/s 2.109GB/s
BM_SendmsgTCP/1/16M 1.902GB/s 1.868GB/s 1.897GB/s
BM_SendmsgTCP/1/32M 1.655GB/s 1.665GB/s 1.635GB/s
BM_SendmsgTCP/1/64M 1.575GB/s 1.547GB/s 1.575GB/s
BM_SendmsgTCP/1/128M 1.524GB/s 1.584GB/s 1.580GB/s
BM_SendmsgTCP/1/256M 1.579GB/s 1.607GB/s 1.593GB/s
PiperOrigin-RevId: 278940079
Diffstat (limited to 'pkg')
47 files changed, 541 insertions, 329 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 3c2b2b5ea..65d4d0cd8 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -6,6 +6,8 @@ package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ + "packet_buffer.go", + "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", ], diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 14f197a77..22eefb564 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -65,14 +65,14 @@ func (e *Endpoint) Drain() int { } } -// Inject injects an inbound packet. -func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.InjectLinkAddr(protocol, "", vv) +// InjectInbound injects an inbound packet. +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil), nil /* linkHeader */) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index ae4858529..edef7db26 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -598,8 +598,8 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e7c05ca4f..7e08e033b 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -43,10 +43,9 @@ const ( ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents buffer.VectorisedView - linkHeader buffer.View + raddr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + contents tcpip.PacketBuffer } type context struct { @@ -93,8 +92,8 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { - c.ch <- packetInfo{remote, protocol, vv, linkHeader} +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + c.ch <- packetInfo{remote, protocol, pkt} } func TestNoEthernetProperties(t *testing.T) { @@ -317,19 +316,21 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: buffer.View(b).ToVectorisedView(), - linkHeader: buffer.View(hdr), + raddr: raddr, + proto: proto, + contents: tcpip.PacketBuffer{ + Data: buffer.View(b).ToVectorisedView(), + LinkHeader: buffer.View(hdr), + }, } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } - // want.contents will be a single view, - // so make pi do the same for the + // want.contents.Data will be a single + // view, so make pi do the same for the // DeepEqual check. - pi.contents = pi.contents.ToView().ToVectorisedView() + pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 554d45715..62ed1e569 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -190,6 +190,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}), buffer.View(eth)) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{ + Data: buffer.View(pkt).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 3331b6453..c67d684ce 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,10 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)) - vv.TrimFront(d.e.hdrSize) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -293,9 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)) - vv.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index a3b48fa73..bc5d8a2f3 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -80,12 +80,13 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) // Because we're immediately turning around and writing the packet back to the // rx path, we intentionally don't preserve the remote and local link // addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + }) return nil } @@ -105,7 +106,10 @@ func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { // There should be an ethernet header at the beginning of packet. linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize]) packet.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), packet, buffer.View(linkHeader)) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ + Data: packet, + LinkHeader: buffer.View(linkHeader), + }) return nil } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 682b60291..9a8e8ebfe 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,8 +80,8 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) } // WritePackets writes outbound packets to the appropriate diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 279e2b457..2bace5298 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -273,8 +273,11 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { } // Send packet up the stack. - eth := header.Ethernet(b) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), buffer.View(eth)) + eth := header.Ethernet(b[:header.EthernetMinimumSize]) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index f3e9705c9..199406886 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,13 +131,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ - addr: remoteLinkAddr, - proto: proto, - vv: vv.Clone(nil), - linkHeader: linkHeader, + addr: remoteLinkAddr, + proto: proto, + vv: pkt.Data.Clone(nil), }) c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 39757ea2a..d71a03cd2 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -116,19 +116,19 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, vv.First(), nil) + logPacket("recv", protocol, pkt.Data.First(), nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - vs := vv.Views() - length := vv.Size() + vs := pkt.Data.Views() + length := pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil { panic(err) } for _, v := range vs { @@ -147,7 +147,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local panic(err) } } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index a04fc1062..b440970e0 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) e.dispatchGate.Leave() } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 5f0f8fa2d..df2e70e54 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.dispatchCount++ } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 46178459e..4161ebf87 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -92,8 +92,8 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - v := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 88b57ec03..47098bfdc 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -102,7 +102,9 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + Data: v.ToVectorisedView(), + }) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 666d8b92a..fe499d47e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -279,7 +279,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -367,7 +369,9 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: vv, + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -430,13 +434,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag1.ToVectorisedView(), + }) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, frag2.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag2.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -509,7 +517,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -618,12 +628,12 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view[:len(view)-c.trunc].ToVectorisedView(), + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 50b363dc4..ce771631c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -24,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) + pkt.Data.TrimFront(hlen) p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return @@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) - vv := vv.Clone(nil) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -104,19 +108,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1339f8474..26f1402ed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -198,7 +198,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payloadSize) id := uint32(0) @@ -218,19 +218,24 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + return ip } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - e.addIPHeader(r, &hdr, payload.Size(), params) + ip := e.addIPHeader(r, &hdr, payload.Size(), params) if loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + NetworkHeader: buffer.View(ip), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -301,7 +306,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + e.HandlePacket(r, tcpip.PacketBuffer{ + Data: payload, + NetworkHeader: buffer.View(ip), + }) } if loop&stack.PacketOut == 0 { return nil @@ -314,22 +322,23 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + pkt.NetworkHeader = headerView[:h.HeaderLength()] hlen := int(h.HeaderLength()) tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { - if vv.Size() == 0 { + if pkt.Data.Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -337,10 +346,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and vv.size() causes a wrap - // around resulting in last being less than the offset. + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. if last < h.FragmentOffset() { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -348,7 +357,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } var ready bool var err error - vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -361,11 +370,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 99f84acd7..f100d84ee 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -464,7 +464,9 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicid, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt})) + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 05e8c075b..58f8e80df 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -25,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip the IP header, then handle the fragmentation header if there // is one. - vv.TrimFront(header.IPv6MinimumSize) + pkt.Data.TrimFront(header.IPv6MinimumSize) p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(vv.First()) + f := header.IPv6Fragment(pkt.Data.First()) if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. @@ -52,19 +52,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip fragmentation header and find out the actual protocol // number. - vv.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return @@ -77,7 +77,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // Only the first view in vv is accounted for by h. To account for the // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. - payload := vv + payload := pkt.Data payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() @@ -113,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) mtu := h.MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -123,10 +123,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch h.Code() { case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: @@ -189,9 +189,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), } hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) @@ -209,7 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // TODO(tamird/ghanan): there exists an explicit NDP option that is // used to update the neighbor table with link addresses for a @@ -285,13 +285,13 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6EchoMinimumSize) + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(packet, h) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, hdr, pkt.Data, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -303,7 +303,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d686f79ce..6037a1ef8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -65,7 +65,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { } type stubLinkAddressCache struct { @@ -147,7 +147,9 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } for _, typ := range types { @@ -280,7 +282,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pkt.Header, pkt.Payload} size := len(pkt.Header) + len(pkt.Payload) vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + Data: vv, + }) } if pkt.Proto != ProtocolNumber { @@ -498,7 +502,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } stats := s.Stats().ICMP.V6PacketsReceived @@ -673,7 +679,9 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } stats := s.Stats().ICMP.V6PacketsReceived @@ -849,9 +857,9 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, - buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, - []buffer.View{hdr.View(), payload})) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) } stats := s.Stats().ICMP.V6PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 5898f8f9e..805d1739c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -97,7 +97,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { length := uint16(hdr.UsedLength() + payloadSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -108,19 +108,24 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + return ip } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - e.addIPHeader(r, &hdr, payload.Size(), params) + ip := e.addIPHeader(r, &hdr, payload.Size(), params) if loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + NetworkHeader: buffer.View(ip), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -160,24 +165,25 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vector // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv6(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { return } - vv.TrimFront(header.IPv6MinimumSize) - vv.CapLength(int(h.PayloadLength())) + pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] + pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data.CapLength(int(h.PayloadLength())) p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, vv) + e.handleICMP(r, headerView, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index deaa9b7f3..1cbfa7278 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -55,7 +55,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stats := s.Stats().ICMP.V6PacketsReceived @@ -111,7 +113,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stat := s.Stats().UDP.PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 69ab7ba12..0dbce14a0 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -98,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, hdr.View().ToVectorisedView()) + ep.HandlePacket(r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } types := []struct { @@ -345,7 +347,9 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) if test.expectedSuccess { if got := invalid.Value(); got != 0 { diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go new file mode 100644 index 000000000..10b04239d --- /dev/null +++ b/pkg/tcpip/packet_buffer.go @@ -0,0 +1,54 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +// +// +stateify savable +type PacketBuffer struct { + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other Views. SR dosen't support this, + // so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +func (pk PacketBuffer) Clone() PacketBuffer { + return PacketBuffer{ + Data: pk.Data.Clone(nil), + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + } +} diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go new file mode 100644 index 000000000..04c4cf136 --- /dev/null +++ b/pkg/tcpip/packet_buffer_state.go @@ -0,0 +1,26 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// beforeSave is invoked by stateify. +func (pk *PacketBuffer) beforeSave() { + // Non-Data fields may be slices of the Data field. This causes + // problems for SR, so during save we make each header independent. + pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) + pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) + pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 525a25218..cc789b5af 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -328,7 +328,9 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 12969c74e..28a28ae6e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -723,10 +723,10 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return nil } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } @@ -736,9 +736,9 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) netProto, ok := n.stack.networkProtocols[protocol] if !ok { @@ -763,22 +763,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, vv.Clone(nil), linkHeader) + ep.HandlePacket(n.id, local, protocol, pkt.Clone()) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } - if len(vv.First()) < netProto.MinimumPacketSize() { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(vv.First()) + src, dst := netProto.ParseAddresses(pkt.Data.First()) if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return } @@ -806,20 +806,20 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(vv.First()) - vv.RemoveFirst() + hdr := buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Data.RemoveFirst() // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, pkt.Data, protocol); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + pkt.Data.Size())) } } return @@ -833,7 +833,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -845,41 +845,41 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) + n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(vv.First()) < transProto.MinimumPacketSize() { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, pkt, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, netHeader, vv) { + if state.defaultHandler(r, id, pkt) { return } } // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { n.stack.stats.MalformedRcvdPackets.Increment() } } // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -890,17 +890,17 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(vv.First()) < 8 { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { return } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d7c124e81..5806d294c 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,16 +64,15 @@ type TransportEndpoint interface { UniqueID() uint64 // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. + // this transport endpoint. It sets pkt.TransportHeader. // - // HandlePacket takes ownership of vv. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g., + // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. - // - // HandleControlPacket takes ownership of vv. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + // HandleControlPacket takes ownership of pkt. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) // Close puts the endpoint in a closed state and frees all resources // associated with it. This cleanup may happen asynchronously. Wait can @@ -99,8 +98,8 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of packet and netHeader. - HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -117,8 +116,8 @@ type PacketEndpoint interface { // linkHeader may have a length of 0, in which case the PacketEndpoint // should construct its own ethernet header for applications. // - // HandlePacket takes ownership of packet and linkHeader. - HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View) + // HandlePacket takes ownership of pkt. + HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -148,7 +147,9 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + // + // HandleUnknownDestinationPacket takes ownership of pkt. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -166,17 +167,21 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. It also returns the network layer - // header for the enpoint to inspect or pass up the stack. + // transport protocol endpoint. + // + // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of vv and netHeader. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) + // DeliverTransportPacket takes ownership of pkt. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. // - // DeliverTransportControlPacket takes ownership of vv. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) + // pkt.NetworkHeader must be set before calling + // DeliverTransportControlPacket. + // + // DeliverTransportControlPacket takes ownership of pkt. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -248,10 +253,10 @@ type NetworkEndpoint interface { NICID() tcpip.NICID // HandlePacket is called by the link layer when new packets arrive to - // this network endpoint. + // this network endpoint. It sets pkt.NetworkHeader. // - // HandlePacket takes ownership of vv. - HandlePacket(r *Route, vv buffer.VectorisedView) + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -294,11 +299,14 @@ type NetworkProtocol interface { // the data link layer. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint - // and hands the packet over for further processing. linkHeader may have - // length 0 when the caller does not have ethernet data. + // and hands the packet over for further processing. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverNetworkPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. // - // DeliverNetworkPacket takes ownership of vv and linkHeader. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) + // DeliverNetworkPacket takes ownership of pkt. + DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -329,7 +337,9 @@ const ( // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. +// out through the implementer's data link endpoint. When a link header exists, +// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the +// stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a @@ -395,7 +405,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8b141cafd..08599d765 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -51,7 +51,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -641,7 +641,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9dae853d0..1fac5477f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -86,28 +86,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) + b := pkt.Data.First() + pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() + nb := pkt.Data.First() if len(nb) < fakeNetHeaderLen { return } - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + pkt.Data.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -138,7 +138,9 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu views[0] = hdr.View() views = append(views, payload.Views()...) vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) + f.HandlePacket(r, tcpip.PacketBuffer{ + Data: vv, + }) } if loop&stack.PacketOut == 0 { return nil @@ -259,7 +261,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -269,7 +273,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -279,7 +285,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -288,7 +296,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -298,7 +308,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -373,7 +385,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -1795,7 +1809,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1855,7 +1871,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) select { case <-ep2.C: diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index ccd3d030e..594570216 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -21,7 +21,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -86,7 +85,7 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { epsByNic.mu.RLock() mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] @@ -100,18 +99,18 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { - mpep.handlePacketAll(r, id, vv) + mpep.handlePacketAll(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { epsByNic.mu.RLock() defer epsByNic.mu.RUnlock() @@ -127,7 +126,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -258,18 +257,16 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpointsArr[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy except for - // the final one. + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, pkt) break } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -395,7 +392,7 @@ var loopbackSubnet = func() tcpip.Subnet { // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if it // found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -408,8 +405,8 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // transport endpoints. var destEps []*endpointsByNic if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, vv, id) - } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { + destEps = d.findAllEndpointsLocked(eps, id) + } else if ep := d.findEndpointLocked(eps, id); ep != nil { destEps = append(destEps, ep) } @@ -424,17 +421,19 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // Deliver the packet. - for _, ep := range destEps { - ep.handlePacket(r, id, vv) + // HandlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEps[:len(destEps)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEps[len(destEps)-1].handlePacket(r, id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -448,7 +447,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + rawEP.HandlePacket(r, pkt) foundRaw = true } eps.mu.RUnlock() @@ -458,7 +457,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -466,7 +465,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, vv, id) + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() // Fail if we didn't find one. @@ -475,12 +474,12 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // Deliver the packet. - ep.handleControlPacket(n, id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, pkt) return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -514,8 +513,8 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv bu // findEndpointLocked returns the endpoint that most closely matches the given // id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { return matchedEPs[0] } return nil diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 210233dc0..f54117c4e 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -156,7 +156,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } func TestTransportDemuxerRegister(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 203e79f56..2cacea99a 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -197,7 +197,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -214,7 +214,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -271,7 +271,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } @@ -342,7 +342,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -351,7 +353,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -360,7 +364,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -413,7 +419,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -422,7 +430,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -431,7 +441,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -584,7 +596,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: req.ToVectorisedView(), + }) aep, _, err := ep.Accept() if err != nil || aep == nil { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 33405eb7d..0092d0ea9 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -718,18 +718,18 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(vv.First()) + h := header.ICMPv4(pkt.Data.First()) if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(vv.First()) + h := header.ICMPv6(pkt.Data.First()) if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -757,19 +757,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &icmpPacket{ + packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, }, } - pkt.data = vv + packet.data = pkt.Data - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() @@ -780,7 +780,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index bfb16f7c3..9ce500e80 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index ead83b83d..26335094e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -266,7 +266,7 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, ethHeader buffer.View) { +func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -289,9 +289,9 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet // TODO(b/129292371): Return network protocol. - if len(ethHeader) > 0 { + if len(pkt.LinkHeader) > 0 { // Get info directly from the ethernet header. - hdr := header.Ethernet(ethHeader) + hdr := header.Ethernet(pkt.LinkHeader) packet.senderAddr = tcpip.FullAddress{ NIC: nicid, Addr: tcpip.Address(hdr.SourceAddress()), @@ -306,11 +306,12 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, if ep.cooked { // Cooked packets can simply be queued. - packet.data = vv + packet.data = pkt.Data } else { // Raw packets need their ethernet headers prepended before // queueing. - if len(ethHeader) == 0 { + var linkHeader buffer.View + if len(pkt.LinkHeader) == 0 { // We weren't provided with an actual ethernet header, // so fake one. ethFields := header.EthernetFields{ @@ -320,10 +321,12 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, } fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) fakeHeader.Encode(ðFields) - ethHeader = buffer.View(fakeHeader) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) } - combinedVV := buffer.View(ethHeader).ToVectorisedView() - combinedVV.Append(vv) + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) packet.data = combinedVV } packet.timestampNS = ep.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 23922a30e..230a1537a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -555,7 +555,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -596,20 +596,21 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &rawPacket{ + packet := &rawPacket{ senderAddr: tcpip.FullAddress{ NIC: route.NICID(), Addr: route.RemoteAddress, }, } - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - pkt.data = combinedVV - pkt.timestampNS = e.stack.NowNanoseconds() + networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) + combinedVV := networkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + packet.timestampNS = e.stack.NowNanoseconds() - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a1efd8d55..e31464c9b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2029,8 +2029,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + s := newSegment(r, id, pkt) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -2065,7 +2065,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 63666f0b3..4983bca81 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index db40785d3..c4f1a84bb 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -126,8 +126,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index c4a89525e..1c10da5ca 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -60,13 +61,13 @@ type segment struct { xmitTime time.Time `state:".(unixTime)"` } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } - s.data = vv.Clone(s.views[:]) + s.data = pkt.Data.Clone(s.views[:]) s.rcvdTime = time.Now() return s } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ef823e4ae..4854e719d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -302,7 +302,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -350,13 +352,17 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.Inject(ipv4.ProtocolNumber, s) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: s, + }) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegment(payload, h), + }) } // SendAck sends an ACK packet. @@ -518,7 +524,9 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // CreateConnected creates a connected TCP endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 03bd5c8fd..4e11de9db 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1158,17 +1158,17 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - vv.TrimFront(header.UDPMinimumSize) + pkt.Data.TrimFront(header.UDPMinimumSize) e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() @@ -1192,18 +1192,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &udpPacket{ + packet := &udpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, Port: hdr.SourcePort(), }, } - pkt.data = vv - e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + packet.data = pkt.Data + e.rcvList.PushBack(packet) + e.rcvBufSize += pkt.Data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() @@ -1214,7 +1214,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index d399ec722..fc706ede2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -16,7 +16,6 @@ package udp import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, id: id, - vv: vv, + pkt: pkt, }) return true @@ -62,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - vv buffer.VectorisedView + pkt tcpip.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that @@ -90,7 +89,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.vv) + ep.HandlePacket(r.route, r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 5c3358a5e..43f11b700 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,10 +66,10 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -116,20 +116,18 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - // The buffers used by vv and netHeader may be used elsewhere - // in the system. For example, a raw or packet socket may use - // what UDP considers an unreachable destination. Thus we deep - // copy vv and netHeader to prevent multiple ownership and SR - // errors. - newNetHeader := make(buffer.View, len(netHeader)) - copy(newNetHeader, netHeader) - payload := buffer.NewVectorisedView(len(newNetHeader), []buffer.View{newNetHeader}) - payload.Append(vv.ToView().ToVectorisedView()) + // The buffers used by pkt may be used elsewhere in the system. + // For example, a raw or packet socket may use what UDP + // considers an unreachable destination. Thus we deep copy pkt + // to prevent multiple ownership and SR errors. + newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) + payload := newNetHeader.ToVectorisedView() + payload.Append(pkt.Data.ToView().ToVectorisedView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -158,12 +156,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload.Append(pkt.Data) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b724d788c..30ee9801b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -397,7 +397,8 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv6(buf) @@ -431,7 +432,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } // injectV4Packet creates a V4 test packet with the given payload and header @@ -441,7 +446,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv4(buf) @@ -471,7 +477,12 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } func newPayload() []byte { |