diff options
Diffstat (limited to 'pkg/tcpip')
70 files changed, 1560 insertions, 841 deletions
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index e57d45f2a..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -22,7 +22,6 @@ go_test( size = "small", srcs = ["gonet_test.go"], library = ":gonet", - tags = ["flaky"], deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 76839eb92..62ac932bb 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -159,6 +159,11 @@ func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } +// More returns whether the more fragments flag is set. +func (b IPv4) More() bool { + return b.Flags()&IPv4FlagMoreFragments != 0 +} + // TTL returns the "TTL" field of the ipv4 header. func (b IPv4) TTL() uint8 { return b[ttl] diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 2c4591409..3499d8399 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -354,6 +354,13 @@ func (b IPv6FragmentExtHdr) ID() uint32 { return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:]) } +// IsAtomic returns whether the fragment header indicates an atomic fragment. An +// atomic fragment is a fragment that contains all the data required to +// reassemble a full packet. +func (b IPv6FragmentExtHdr) IsAtomic() bool { + return !b.More() && b.FragmentOffset() == 0 +} + // IPv6PayloadIterator is an iterator over the contents of an IPv6 payload. // // The IPv6 payload may contain IPv6 extension headers before any upper layer diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 5eb78b398..20b183da0 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -181,12 +181,12 @@ func (e *Endpoint) NumQueued() int { } // InjectInbound injects an inbound packet. -func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) { +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } @@ -229,13 +229,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() p := PacketInfo{ - Pkt: &pkt, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 5ee508d48..f34082e1a 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -387,7 +387,7 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) @@ -641,7 +641,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 6f41a71a8..eaee7e5d7 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents stack.PacketBuffer + contents *stack.PacketBuffer } type context struct { @@ -103,7 +103,7 @@ func (c *context) cleanup() { } } -func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.ch <- packetInfo{remote, protocol, pkt} } @@ -179,7 +179,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), Hash: hash, @@ -295,7 +295,7 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buffer.VectorisedView{}, }); err != nil { @@ -358,7 +358,7 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: stack.PacketBuffer{ + contents: &stack.PacketBuffer{ Data: buffer.View(b).ToVectorisedView(), LinkHeader: buffer.View(hdr), }, diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index ca4229ed6..2dfd29aa9 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -191,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, stack.PacketBuffer{ + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ Data: buffer.View(pkt).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 26c96a655..f04738cfb 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), LinkHeader: buffer.View(eth), } @@ -296,7 +296,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), LinkHeader: buffer.View(eth), } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 20d9e95f6..568c6874f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index f0769830a..c69d6b7e9 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,7 +80,7 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } @@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts s // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 87c734c1f..0744f66d6 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) @@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewView(0).ToVectorisedView(), }) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index ec5c5048a..b5dfb7850 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -102,7 +102,7 @@ func (q *queueDispatcher) dispatchLoop() { } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. -func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -146,7 +146,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. newRoute := r.Clone() @@ -154,7 +154,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] - if !d.q.enqueue(&pkt) { + if !d.q.enqueue(pkt) { return tcpip.ErrNoBufferSpace } d.newPacketWaker.Assert() diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index f5dec0a7f..0374a2441 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Add the ethernet header here. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) @@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index f3fc62607..28a2e88ba 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, @@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, }); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }) @@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: uu, }); err != want { @@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index b060d4627..f2e47b6a7 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -47,13 +47,6 @@ var LogPackets uint32 = 1 // LogPacketsToPCAP must be accessed atomically. var LogPacketsToPCAP uint32 = 1 -var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{ - header.ICMPv4ProtocolNumber: header.IPv4MinimumSize, - header.ICMPv6ProtocolNumber: header.IPv6MinimumSize, - header.UDPProtocolNumber: header.UDPMinimumSize, - header.TCPProtocolNumber: header.TCPMinimumSize, -} - type endpoint struct { dispatcher stack.NetworkDispatcher lower stack.LinkEndpoint @@ -120,8 +113,8 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, &pkt) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dumpPacket("recv", nil, protocol, pkt) e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -208,8 +201,8 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, &pkt) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.dumpPacket("send", gso, protocol, pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } @@ -287,7 +280,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P vv.TrimFront(header.ARPSize) arp := header.ARP(hdr) log.Infof( - "%s arp %v (%v) -> %v (%v) valid:%v", + "%s arp %s (%s) -> %s (%s) valid:%t", prefix, tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), @@ -299,13 +292,6 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P return } - // We aren't guaranteed to have a transport header - it's possible for - // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { - log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) - return - } - // Figure out the transport layer info. transName := "unknown" srcPort := uint16(0) @@ -346,7 +332,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P icmpType = "info reply" } } - log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.ICMPv6ProtocolNumber: @@ -381,7 +367,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6RedirectMsg: icmpType = "redirect message" } - log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.UDPProtocolNumber: @@ -428,7 +414,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P flagsStr[i] = ' ' } } - details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) + details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) if flags&header.TCPFlagSyn != 0 { details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0)) } else { @@ -437,7 +423,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P } default: - log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto) + log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto) return } @@ -445,5 +431,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P details += fmt.Sprintf(" gso: %+v", gso) } - log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) + log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 617446ea2..6bc9033d0 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -213,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.View(data).ToVectorisedView(), } if ethHdr != nil { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index f5a05929f..949b3f2b2 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,7 +50,7 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } @@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 0a9b99f18..63bf40562 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatchCount++ } @@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9d0797af7..7f27a840d 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -80,7 +80,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -94,16 +94,12 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } - h := header.ARP(v) +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.ARP(pkt.NetworkHeader) if !h.IsValid() { return } @@ -122,7 +118,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) fallthrough // also fill the cache from requests @@ -177,7 +173,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } @@ -209,6 +205,17 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.NetworkProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(header.ARPSize) + return 0, false, true +} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 1646d9cde..66e67429c 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ Data: v.ToVectorisedView(), }) } diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index f42abc4bb..2982450f8 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -81,8 +81,8 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t } } -// Process processes an incoming fragment belonging to an ID -// and returns a complete packet when all the packets belonging to that ID have been received. +// Process processes an incoming fragment belonging to an ID and returns a +// complete packet when all the packets belonging to that ID have been received. func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { f.mu.Lock() r, ok := f.reassemblers[id] diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4c20301c6..7c8fb3e0a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } @@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) @@ -150,7 +150,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -246,7 +246,11 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -289,9 +293,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: view.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -378,10 +382,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: vv, - }) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -444,17 +445,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: frag1.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: frag2.ToVectorisedView(), - }) + pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -487,7 +488,11 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -530,9 +535,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: view.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -644,12 +649,25 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: view[:len(view)-c.trunc].ToVectorisedView(), - }) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } }) } } + +// truncatedPacket returns a PacketBuffer based on a truncated view. If view, +// after truncation, is large enough to hold a network header, it makes part of +// view the packet's NetworkHeader and the rest its Data. Otherwise all of view +// becomes Data. +func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { + v := view[:len(view)-trunc] + if len(v) < netHdrLen { + return &stack.PacketBuffer{Data: v.ToVectorisedView()} + } + return &stack.PacketBuffer{ + NetworkHeader: v[:netHdrLen], + Data: v[netHdrLen:].ToVectorisedView(), + } +} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..1b67aa066 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -24,7 +24,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -56,9 +56,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { received.Invalid.Increment() @@ -88,7 +91,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ Data: pkt.Data.Clone(nil), NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), }) @@ -102,7 +105,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: vv, TransportHeader: buffer.View(pkt), diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 64046cbbf..7e9f16c90 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -21,6 +21,7 @@ package ipv4 import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -129,7 +130,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { // packet's stated length matches the length of the header+payload. mtu // includes the IP header and options. This does not support the DontFragment // IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() @@ -169,7 +170,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, if i > 0 { newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -188,7 +189,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -202,7 +203,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr := pkt.Header startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: startOfHdr, Data: emptyVV, NetworkHeader: buffer.View(h), @@ -245,7 +246,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -253,43 +254,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. return nil } + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than + // only NATted packets, but removing this check short circuits broadcasts + // before they are sent out to other hosts. if pkt.NatDone { - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. netHeader := header.IPv4(pkt.NetworkHeader) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) return nil } } if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) - + e.HandlePacket(&loopedR, pkt) loopedR.Release() } if r.Loop&stack.PacketOut == 0 { @@ -342,23 +329,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { + if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + ep.HandlePacket(&route, pkt) n++ continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } @@ -370,7 +350,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) @@ -426,35 +406,23 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.IPv4(pkt.NetworkHeader) + if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv4(headerView) - if !h.IsValid(pkt.Data.Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } - pkt.NetworkHeader = headerView[:h.HeaderLength()] - - hlen := int(h.HeaderLength()) - tlen := int(h.TotalLength()) - pkt.Data.TrimFront(hlen) - pkt.Data.CapLength(tlen - hlen) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. return } - more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 - if more || h.FragmentOffset() != 0 { - if pkt.Data.Size() == 0 { + if h.More() || h.FragmentOffset() != 0 { + if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -473,7 +441,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } var ready bool var err error - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -485,7 +453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - headerView.CapLength(hlen) + pkt.NetworkHeader.CapLength(int(h.HeaderLength())) e.handleICMP(r, pkt) return } @@ -565,6 +533,35 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv4(hdr) + + // If there are options, pull those into hdr as well. + if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(headerLen) + if !ok { + panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) + } + ipHdr = header.IPv4(hdr) + } + + // If this is a fragment, don't bother parsing the transport header. + parseTransportHeader := true + if ipHdr.More() || ipHdr.FragmentOffset() != 0 { + parseTransportHeader = false + } + + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(len(hdr)) + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return ipHdr.TransportProtocol(), parseTransportHeader, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 36035c820..11e579c4b 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -114,7 +114,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { +func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) @@ -174,7 +174,7 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn type errorChannel struct { *channel.Endpoint - Ch chan stack.PacketBuffer + Ch chan *stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -184,7 +184,7 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan stack.PacketBuffer, size), + Ch: make(chan *stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } @@ -203,7 +203,7 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { select { case e.Ch <- pkt: default: @@ -282,13 +282,17 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := stack.PacketBuffer{ + source := &stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -296,7 +300,7 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []stack.PacketBuffer + var results []*stack.PacketBuffer L: for { select { @@ -338,7 +342,11 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -460,7 +468,7 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), }) } @@ -644,6 +652,18 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: [][]byte{udpPayload1, udpPayload2}, }, + { + name: "Fragment without followup", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1[:64], + }, + }, + expectedPayloads: nil, + }, } for _, test := range tests { @@ -698,7 +718,7 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..2ff7eedf4 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -70,17 +70,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) if !ok { received.Invalid.Increment() return } h := header.ICMPv6(v) - iph := header.IPv6(netHeader) + iph := header.IPv6(pkt.NetworkHeader) // Validate ICMPv6 checksum before processing the packet. // @@ -288,7 +291,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, }); err != nil { sent.Dropped.Increment() @@ -390,7 +393,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: pkt.Data, }); err != nil { @@ -532,7 +535,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..52a01b44e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -57,7 +57,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return nil } @@ -67,7 +67,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) { } type stubLinkAddressCache struct { @@ -179,36 +179,32 @@ func TestICMPCounts(t *testing.T) { }, } - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + handleIPv6Payload := func(icmp header.ICMPv6) { + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + ep.HandlePacket(&r, &stack.PacketBuffer{ + NetworkHeader: buffer.View(ip), + Data: buffer.View(icmp).ToVectorisedView(), }) } for _, typ := range types { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) - - handleIPv6Payload(hdr) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -328,7 +324,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ Data: vv, }) } @@ -546,25 +542,22 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } handleIPv6Payload := func(checksum bool) { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(typ.size + extraDataLen), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) } @@ -740,7 +733,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -918,7 +911,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index daf1fcbc6..95fbcf2d1 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -116,7 +116,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -128,7 +128,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, stack.PacketBuffer{ + e.HandlePacket(&loopedR, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -163,30 +163,28 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + h := header.IPv6(pkt.NetworkHeader) + if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv6(headerView) - if !h.IsValid(pkt.Data.Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } - - pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] - pkt.Data.TrimFront(header.IPv6MinimumSize) - pkt.Data.CapLength(int(h.PayloadLength())) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data) + // vv consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader) + vv.Append(pkt.Data) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false for firstHeader := true; ; firstHeader = false { @@ -262,9 +260,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { case header.IPv6FragmentExtHdr: hasFragmentHeader = true - fragmentOffset := extHdr.FragmentOffset() - more := extHdr.More() - if !more && fragmentOffset == 0 { + if extHdr.IsAtomic() { // This fragment extension header indicates that this packet is an // atomic fragment. An atomic fragment is a fragment that contains // all the data required to reassemble a full packet. As per RFC 6946, @@ -277,9 +273,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. - rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */) + rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */) - if fragmentOffset == 0 { + if extHdr.FragmentOffset() == 0 { // Check that the iterator ends with a raw payload as the first fragment // should include all headers up to and including any upper layer // headers, as per RFC 8200 section 4.5; only upper layer data @@ -332,7 +328,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } // The packet is a fragment, let's try to reassemble it. - start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit last := start + uint16(fragmentPayloadLen) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the @@ -345,7 +341,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } var ready bool - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf) + // Note that pkt doesn't have its transport header set after reassembly, + // and won't until DeliverNetworkPacket sets it. + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -394,10 +392,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { case header.IPv6RawPayloadHeader: // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + + // For unfragmented packets, extHdr still contains the transport header. + // Get rid of it. + // + // For reassembled fragments, pkt.TransportHeader is unset, so this is a + // no-op and pkt.Data begins with the transport header. + extHdr.Buf.TrimFront(len(pkt.TransportHeader)) pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, pkt, hasFragmentHeader) + e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error @@ -505,6 +510,79 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + // + // Parsing occurs again in HandlePacket because we don't track the + // extensions in PacketBuffer. Unfortunately, that means HandlePacket + // has to do the parsing work again. + var nextHdr tcpip.TransportProtocolNumber + foundNext := true + extensionsSize := 0 +traverseExtensions: + for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { + if err != nil { + break + } + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + foundNext = false + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + // If this is an atomic fragment, we don't have to treat it specially. + if !extHdr.More() && extHdr.FragmentOffset() == 0 { + continue + } + // This is a non-atomic fragment and has to be re-assembled before we can + // examine the payload for a transport header. + foundNext = false + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader. + hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(len(hdr)) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + + return nextHdr, foundNext, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 841a0cb7a..213ff64f2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,7 +65,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -123,7 +123,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -637,7 +637,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -1238,7 +1238,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 12b70f7e9..64239ce9a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,7 +136,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -380,7 +380,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -497,7 +497,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -551,25 +551,29 @@ func TestNDPValidation(t *testing.T) { return s, ep, r } - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { nextHdr := uint8(header.ICMPv6ProtocolNumber) + var extensions buffer.View if atomicFragment { - bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength) - bytes[0] = nextHdr + extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) + extensions[0] = nextHdr nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) } - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, HopLimit: hopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { + t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) + } + ep.HandlePacket(r, &stack.PacketBuffer{ + NetworkHeader: buffer.View(ip), + Data: payload.ToVectorisedView(), }) } @@ -676,14 +680,11 @@ func TestNDPValidation(t *testing.T) { invalid := stats.Invalid typStat := typ.statCounter(stats) - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetCode(test.code) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -699,7 +700,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { @@ -884,7 +885,7 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index b937cb84b..edc29ad27 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -54,17 +54,27 @@ type Flags struct { LoadBalanced bool } -func (f Flags) bits() reuseFlag { - var rf reuseFlag +// Bits converts the Flags to their bitset form. +func (f Flags) Bits() BitFlags { + var rf BitFlags if f.MostRecent { - rf |= mostRecentFlag + rf |= MostRecentFlag } if f.LoadBalanced { - rf |= loadBalancedFlag + rf |= LoadBalancedFlag } return rf } +// Effective returns the effective behavior of a flag config. +func (f Flags) Effective() Flags { + e := f + if e.LoadBalanced && e.MostRecent { + e.MostRecent = false + } + return e +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex @@ -78,56 +88,88 @@ type PortManager struct { hint uint32 } -type reuseFlag int +// BitFlags is a bitset representation of Flags. +type BitFlags uint32 const ( - mostRecentFlag reuseFlag = 1 << iota - loadBalancedFlag + // MostRecentFlag represents Flags.MostRecent. + MostRecentFlag BitFlags = 1 << iota + + // LoadBalancedFlag represents Flags.LoadBalanced. + LoadBalancedFlag + + // nextFlag is the value that the next added flag will have. + // + // It is used to calculate FlagMask below. It is also the number of + // valid flag states. nextFlag - flagMask = nextFlag - 1 + // FlagMask is a bit mask for BitFlags. + FlagMask = nextFlag - 1 ) -type portNode struct { - // refs stores the count for each possible flag combination. +// ToFlags converts the bitset into a Flags struct. +func (f BitFlags) ToFlags() Flags { + return Flags{ + MostRecent: f&MostRecentFlag != 0, + LoadBalanced: f&LoadBalancedFlag != 0, + } +} + +// FlagCounter counts how many references each flag combination has. +type FlagCounter struct { + // refs stores the count for each possible flag combination, (0 though + // FlagMask). refs [nextFlag]int } -func (p portNode) totalRefs() int { +// AddRef increases the reference count for a specific flag combination. +func (c *FlagCounter) AddRef(flags BitFlags) { + c.refs[flags]++ +} + +// DropRef decreases the reference count for a specific flag combination. +func (c *FlagCounter) DropRef(flags BitFlags) { + c.refs[flags]-- +} + +// TotalRefs calculates the total number of references for all flag +// combinations. +func (c FlagCounter) TotalRefs() int { var total int - for _, r := range p.refs { + for _, r := range c.refs { total += r } return total } -// flagRefs returns the number of references with all specified flags. -func (p portNode) flagRefs(flags reuseFlag) int { +// FlagRefs returns the number of references with all specified flags. +func (c FlagCounter) FlagRefs(flags BitFlags) int { var total int - for i, r := range p.refs { - if reuseFlag(i)&flags == flags { + for i, r := range c.refs { + if BitFlags(i)&flags == flags { total += r } } return total } -// allRefsHave returns if all references have all specified flags. -func (p portNode) allRefsHave(flags reuseFlag) bool { - for i, r := range p.refs { - if reuseFlag(i)&flags == flags && r > 0 { +// AllRefsHave returns if all references have all specified flags. +func (c FlagCounter) AllRefsHave(flags BitFlags) bool { + for i, r := range c.refs { + if BitFlags(i)&flags != flags && r > 0 { return false } } return true } -// intersectionRefs returns the set of flags shared by all references. -func (p portNode) intersectionRefs() reuseFlag { - intersection := flagMask - for i, r := range p.refs { +// IntersectionRefs returns the set of flags shared by all references. +func (c FlagCounter) IntersectionRefs() BitFlags { + intersection := FlagMask + for i, r := range c.refs { if r > 0 { - intersection &= reuseFlag(i) + intersection &= BitFlags(i) } } return intersection @@ -135,26 +177,26 @@ func (p portNode) intersectionRefs() reuseFlag { // deviceNode is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]portNode +type deviceNode map[tcpip.NICID]FlagCounter // isAvailable checks whether binding is possible by device. If not binding to a -// device, check against all portNodes. If binding to a specific device, check +// device, check against all FlagCounters. If binding to a specific device, check // against the unspecified device and the provided device. // // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { - flagBits := flags.bits() + flagBits := flags.Bits() if bindToDevice == 0 { // Trying to binding all devices. if flagBits == 0 { // Can't bind because the (addr,port) is already bound. return false } - intersection := flagMask + intersection := FlagMask for _, p := range d { - i := p.intersectionRefs() + i := p.IntersectionRefs() intersection &= i if intersection&flagBits == 0 { // Can't bind because the (addr,port) was @@ -165,17 +207,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { return true } - intersection := flagMask + intersection := FlagMask if p, ok := d[0]; ok { - intersection = p.intersectionRefs() + intersection = p.IntersectionRefs() if intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - i := p.intersectionRefs() + i := p.IntersectionRefs() intersection &= i if intersection&flagBits == 0 { return false @@ -324,7 +366,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { return false } - flagBits := flags.bits() + flagBits := flags.Bits() // Reserve port on all network protocols. for _, network := range networks { @@ -340,7 +382,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m[addr] = d } n := d[bindToDevice] - n.refs[flagBits]++ + n.AddRef(flagBits) d[bindToDevice] = n } @@ -353,7 +395,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp s.mu.Lock() defer s.mu.Unlock() - flagBits := flags.bits() + flagBits := flags.Bits() for _, network := range networks { desc := portDescriptor{network, transport, port} @@ -368,7 +410,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp } n.refs[flagBits]-- d[bindToDevice] = n - if n.refs == [nextFlag]int{} { + if n.TotalRefs() == 0 { delete(d, bindToDevice) } if len(d) == 0 { diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index f71073207..24f52b735 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -89,6 +89,7 @@ go_test( "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/ports", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -110,5 +111,6 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", ], ) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 7d1ede1f2..05bf62788 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -20,7 +20,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" @@ -147,46 +146,8 @@ type ConnTrackTable struct { Seed uint32 } -// parseHeaders sets headers in the packet. -func parseHeaders(pkt *PacketBuffer) { - newPkt := pkt.Clone() - - // Set network header. - hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - netHeader := header.IPv4(hdr) - newPkt.NetworkHeader = hdr - length := int(netHeader.HeaderLength()) - - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - // Set transport header. - switch protocol := netHeader.TransportProtocol(); protocol { - case header.UDPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.UDP(h[length:])) - } - case header.TCPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.TCP(h[length:])) - } - } - pkt.NetworkHeader = newPkt.NetworkHeader - pkt.TransportHeader = newPkt.TransportHeader -} - // packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { +func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { var tuple connTrackTuple netHeader := header.IPv4(pkt.NetworkHeader) @@ -257,15 +218,8 @@ func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { // TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other // transport protocols. func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - if hook == Prerouting { - // Headers will not be set in Prerouting. - // TODO(gvisor.dev/issue/170): Change this after parsing headers - // code is added. - parseHeaders(pkt) - } - var dir ctDirection - tuple, err := packetToTuple(*pkt, hook) + tuple, err := packetToTuple(pkt, hook) if err != nil { return nil, dir } diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 6b64cd37f..3eff141e6 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go @@ -32,7 +32,7 @@ type pendingPacket struct { nic *NIC route *Route proto tcpip.NetworkProtocolNumber - pkt PacketBuffer + pkt *PacketBuffer } type forwardQueue struct { @@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue { return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { shouldWait := false f.Lock() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 344d60baa..a6546cef0 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -33,6 +33,10 @@ const ( // except where another value is explicitly used. It is chosen to match // the MTU of loopback interfaces on linux systems. fwdTestNetDefaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. @@ -68,16 +72,9 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { return &f.id } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { - // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fwdTestNetHeaderLen) - +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -96,13 +93,13 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + b[dstAddrOffset] = r.RemoteAddress[0] + b[srcAddrOffset] = f.id.LocalAddress[0] + b[protocolNumberOffset] = byte(params.Protocol) return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) } @@ -112,7 +109,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -140,7 +137,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { } func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) +} + +func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = netHeader + pkt.Data.TrimFront(fwdTestNetHeaderLen) + return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -190,7 +197,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress LocalLinkAddress tcpip.LinkAddress - Pkt PacketBuffer + Pkt *PacketBuffer } type fwdTestLinkEndpoint struct { @@ -203,12 +210,12 @@ type fwdTestLinkEndpoint struct { } // InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } @@ -251,7 +258,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -270,7 +277,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, *pkt) + e.WritePacket(r, gso, protocol, pkt) n++ } @@ -280,7 +287,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: PacketBuffer{Data: vv}, + Pkt: &PacketBuffer{Data: vv}, } select { @@ -361,8 +368,8 @@ func TestForwardingWithStaticResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -398,8 +405,8 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -429,8 +436,8 @@ func TestForwardingWithNoResolver(t *testing.T) { // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -459,16 +466,16 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. buf := buffer.NewView(30) - buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -480,9 +487,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -509,8 +515,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -524,9 +530,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -554,10 +559,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -571,14 +576,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). + if !ok { + t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) } - // The first 5 packets should not be forwarded so the the - // sequemnce number should start with 5. + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. want := uint16(i + 5) - if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { t.Fatalf("got the packet #%d, want = #%d", n, want) } @@ -609,8 +618,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Each packet has a different destination address (3 to // maxPendingResolutions + 7). buf := buffer.NewView(30) - buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -626,9 +635,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() - if b[0] < 8 { - t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 709ede3fa..4e9b404c8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -43,11 +43,11 @@ const HookUnset = -1 // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() IPTables { +func DefaultTables() *IPTables { // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for // iotas. - return IPTables{ - Tables: map[string]Table{ + return &IPTables{ + tables: map[string]Table{ TablenameNat: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -106,7 +106,7 @@ func DefaultTables() IPTables { UserChains: map[string]int{}, }, }, - Priorities: map[Hook][]string{ + priorities: map[Hook][]string{ Input: []string{TablenameNat, TablenameFilter}, Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, @@ -158,6 +158,36 @@ func EmptyNatTable() Table { } } +// GetTable returns table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + it.mu.RLock() + defer it.mu.RUnlock() + t, ok := it.tables[name] + return t, ok +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) { + it.mu.Lock() + defer it.mu.Unlock() + it.tables[name] = table +} + +// ModifyTables acquires write-lock and calls fn with internal name-to-table +// map. This function can be used to update multiple tables atomically. +func (it *IPTables) ModifyTables(fn func(map[string]Table)) { + it.mu.Lock() + defer it.mu.Unlock() + fn(it.tables) +} + +// GetPriorities returns slice of priorities associated with hook. +func (it *IPTables) GetPriorities(hook Hook) []string { + it.mu.RLock() + defer it.mu.RUnlock() + return it.priorities[hook] +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -184,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr it.connections.HandlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] + for _, tablename := range it.GetPriorities(hook) { + table, _ := it.GetTable(tablename) ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. @@ -321,7 +351,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, *pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 36cc6275d..92e31643e 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -98,11 +98,6 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 } - // Set network header. - if hook == Prerouting { - parseHeaders(pkt) - } - // Drop the packet if network and transport header are not set. if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index a3bd3e700..4a6a5c6f1 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -16,6 +16,7 @@ package stack import ( "strings" + "sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,17 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table + // mu protects tables and priorities. + mu sync.RWMutex - // Priorities maps each hook to a list of table names. The order of the + // tables maps table names to tables. User tables have arbitrary names. mu + // needs to be locked for accessing. + tables map[string]Table + + // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string + // hook. mu needs to be locked for accessing. + priorities map[Hook][]string connections ConnTrackTable } @@ -245,7 +250,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 526c7d6ff..e28c23d66 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -467,8 +467,17 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - // The timer used to send the next router solicitation message. - rtrSolicitTimer *time.Timer + rtrSolicit struct { + // The timer used to send the next router solicitation message. + timer *time.Timer + + // Used to let the Router Solicitation timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the NIC this ndpState is associated with. MUST be set when the timer is + // set. + done *bool + } // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -648,13 +657,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. timer = time.AfterFunc(0, func() { - ndp.nic.mu.RLock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + if done { // If we reach this point, it means that the DAD timer fired after // another goroutine already obtained the NIC lock and stopped DAD // before this function obtained the NIC lock. Simply return here and do // nothing further. - ndp.nic.mu.RUnlock() return } @@ -665,15 +675,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } dadDone := remaining == 0 - ndp.nic.mu.RUnlock() var err *tcpip.Error if !dadDone { - err = ndp.sendDADPacket(addr) + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + + // Do not hold the lock when sending packets which may be a long running + // task or may block link address resolution. We know this is safe + // because immediately after obtaining the lock again, we check if DAD + // has been stopped before doing any work with the NIC. Note, DAD would be + // stopped if the NIC was disabled or removed, or if the address was + // removed. + ndp.nic.mu.Unlock() + err = ndp.sendDADPacket(addr, ref) + ndp.nic.mu.Lock() } - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() if done { // If we reach this point, it means that DAD was stopped after we released // the NIC's read lock and before we obtained the write lock. @@ -721,17 +739,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // addr. // // addr must be a tentative IPv6 address on ndp's NIC. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { +// +// The NIC ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { + return err + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) @@ -750,7 +775,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -1816,7 +1841,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicitTimer != nil { + if ndp.rtrSolicit.timer != nil { // We are already soliciting routers. return } @@ -1833,14 +1858,27 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { + var done bool + ndp.rtrSolicit.done = &done + ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that the RS timer fired after another + // goroutine already obtained the NIC lock and stopped solicitations. + // Simply return here and do nothing further. + ndp.nic.mu.Unlock() + return + } + // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) if ref == nil { - ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) } + ndp.nic.mu.Unlock() + localAddr := ref.ep.ID().LocalAddress r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() @@ -1849,6 +1887,13 @@ func (ndp *ndpState) startSolicitingRouters() { // header.IPv6AllRoutersMulticastAddress is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { + return + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) @@ -1881,7 +1926,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) @@ -1893,17 +1938,18 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - if remaining == 0 { - ndp.rtrSolicitTimer = nil - } else if ndp.rtrSolicitTimer != nil { + if done || remaining == 0 { + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil + } else if ndp.rtrSolicit.timer != nil { // Note, we need to explicitly check to make sure that // the timer field is not nil because if it was nil but // we still reached this point, then we know the NIC // was requested to stop soliciting routers so we don't // need to send the next Router Solicitation message. - ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval) + ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) } + ndp.nic.mu.Unlock() }) } @@ -1913,13 +1959,15 @@ func (ndp *ndpState) startSolicitingRouters() { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicitTimer == nil { + if ndp.rtrSolicit.timer == nil { // Nothing to do. return } - ndp.rtrSolicitTimer.Stop() - ndp.rtrSolicitTimer = nil + *ndp.rtrSolicit.done = true + ndp.rtrSolicit.timer.Stop() + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil } // initializeTempAddrState initializes state related to temporary SLAAC diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b3d174cdd..58f1ebf60 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -613,7 +613,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -935,7 +935,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -970,14 +970,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -986,7 +986,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -994,7 +994,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -1003,7 +1003,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 05646e5e2..644c0d437 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -457,8 +457,20 @@ type ipv6AddrCandidate struct { // remoteAddr must be a valid IPv6 address. func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { n.mu.RLock() - defer n.mu.RUnlock() + ref := n.primaryIPv6EndpointRLocked(remoteAddr) + n.mu.RUnlock() + return ref +} +// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 and 7 are followed. +// +// remoteAddr must be a valid IPv6 address. +// +// n.mu MUST be read locked. +func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] if len(primaryAddrs) == 0 { @@ -568,11 +580,6 @@ const ( // promiscuous indicates that the NIC's promiscuous flag should be observed // when getting a NIC's referenced network endpoint. promiscuous - - // forceSpoofing indicates that the NIC should be assumed to be spoofing, - // regardless of what the NIC's spoofing flag is when getting a NIC's - // referenced network endpoint. - forceSpoofing ) func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { @@ -591,8 +598,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { - id := NetworkEndpointID{address} - n.mu.RLock() var spoofingOrPromiscuous bool @@ -601,11 +606,9 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t spoofingOrPromiscuous = n.mu.spoofing case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous - case forceSpoofing: - spoofingOrPromiscuous = true } - if ref, ok := n.mu.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // An endpoint with this id exists, check if it can be used and return it. switch ref.getKind() { case permanentExpired: @@ -654,11 +657,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.mu.endpoints[id]; ok { + ref := n.getRefOrCreateTempLocked(protocol, address, peb) + n.mu.Unlock() + return ref +} + +/// getRefOrCreateTempLocked returns an existing endpoint for address or creates +/// and returns a temporary endpoint. +func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { - n.mu.Unlock() return ref } // tryIncRef failing means the endpoint is scheduled to be removed once the @@ -670,7 +680,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { - n.mu.Unlock() return nil } ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ @@ -681,7 +690,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t }, }, peb, temporary, static, false) - n.mu.Unlock() return ref } @@ -1153,7 +1161,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1167,7 +1175,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1212,12 +1220,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + // Parse headers. + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { + // The packet is too small to contain a network header. n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + if hasTransportHdr { + // Parse the transport header if present. + if state, ok := n.stack.transportProtocols[transProtoNum]; ok { + state.proto.Parse(pkt) + } + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1229,11 +1246,12 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. - if protocol == header.IPv4ProtocolNumber { + // Loopback traffic skips the prerouting chain. + if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { + if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } @@ -1298,10 +1316,20 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this + // by having lower layers explicity write each header instead of just + // pkt.Header. + + // pkt may have set its NetworkHeader and TransportHeader. If we're + // forwarding, we'll have to copy them into pkt.Header. + pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) + if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) + } + if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) } // WritePacket takes ownership of pkt, calculate numBytes first. @@ -1318,7 +1346,7 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1332,13 +1360,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // from fragments. + if pkt.TransportHeader == nil { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + pkt.TransportHeader = transHeader + } else { + // This is either a bad packet or was re-assembled from fragments. + transProto.Parse(pkt) + } + } + + if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1365,7 +1411,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index b01b3f476..31f865260 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,268 @@ package stack import ( + "math" "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) +var _ LinkEndpoint = (*testLinkEndpoint)(nil) + +// A LinkEndpoint that throws away outgoing packets. +// +// We use this instead of the channel endpoint as the channel package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testLinkEndpoint struct { + dispatcher NetworkDispatcher +} + +// Attach implements LinkEndpoint.Attach. +func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements LinkEndpoint.IsAttached. +func (e *testLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements LinkEndpoint.MTU. +func (*testLinkEndpoint) MTU() uint32 { + return math.MaxUint16 +} + +// Capabilities implements LinkEndpoint.Capabilities. +func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { + return CapabilityResolutionRequired +} + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*testLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// Wait implements LinkEndpoint.Wait. +func (*testLinkEndpoint) Wait() {} + +// WritePacket implements LinkEndpoint.WritePacket. +func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) + +// An IPv6 NetworkEndpoint that throws away outgoing packets. +// +// We use this instead of ipv6.endpoint because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Endpoint struct { + nicID tcpip.NICID + id NetworkEndpointID + prefixLen int + linkEP LinkEndpoint + protocol *testIPv6Protocol +} + +// DefaultTTL implements NetworkEndpoint.DefaultTTL. +func (*testIPv6Endpoint) DefaultTTL() uint8 { + return 0 +} + +// MTU implements NetworkEndpoint.MTU. +func (e *testIPv6Endpoint) MTU() uint32 { + return e.linkEP.MTU() - header.IPv6MinimumSize +} + +// Capabilities implements NetworkEndpoint.Capabilities. +func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. +func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket implements NetworkEndpoint.WritePacket. +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements NetworkEndpoint.WritePackets. +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteHeaderIncludedPacket implements +// NetworkEndpoint.WriteHeaderIncludedPacket. +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +// ID implements NetworkEndpoint.ID. +func (e *testIPv6Endpoint) ID() *NetworkEndpointID { + return &e.id +} + +// PrefixLen implements NetworkEndpoint.PrefixLen. +func (e *testIPv6Endpoint) PrefixLen() int { + return e.prefixLen +} + +// NICID implements NetworkEndpoint.NICID. +func (e *testIPv6Endpoint) NICID() tcpip.NICID { + return e.nicID +} + +// HandlePacket implements NetworkEndpoint.HandlePacket. +func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { +} + +// Close implements NetworkEndpoint.Close. +func (*testIPv6Endpoint) Close() {} + +// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. +func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +var _ NetworkProtocol = (*testIPv6Protocol)(nil) + +// An IPv6 NetworkProtocol that supports the bare minimum to make a stack +// believe it supports IPv6. +// +// We use this instead of ipv6.protocol because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Protocol struct{} + +// Number implements NetworkProtocol.Number. +func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. +func (*testIPv6Protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. +func (*testIPv6Protocol) DefaultPrefixLen() int { + return header.IPv6AddressSize * 8 +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint implements NetworkProtocol.NewEndpoint. +func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { + return &testIPv6Endpoint{ + nicID: nicID, + id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + linkEP: linkEP, + protocol: p, + }, nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { + return nil +} + +// Option implements NetworkProtocol.Option. +func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { + return nil +} + +// Close implements NetworkProtocol.Close. +func (*testIPv6Protocol) Close() {} + +// Wait implements NetworkProtocol.Wait. +func (*testIPv6Protocol) Wait() {} + +// Parse implements NetworkProtocol.Parse. +func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + return 0, false, false +} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) + +// LinkAddressProtocol implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// LinkAddressRequest implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { + return nil +} + +// ResolveStaticAddress implements LinkAddressResolver. +func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if header.IsV6MulticastAddress(addr) { + return header.EthernetAddressFromMulticastIPv6Address(addr), true + } + return "", false +} + +// Test the race condition where a NIC is removed and an RS timer fires at the +// same time. +func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { + const ( + nicID = 1 + + maxRtrSolicitations = 5 + ) + + e := testLinkEndpoint{} + s := New(Options{ + NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, + NDPConfigs: NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: minimumRtrSolicitationInterval, + }, + }) + + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) + } + + s.mu.Lock() + // Wait for the router solicitation timer to fire and block trying to obtain + // the stack lock when doing link address resolution. + time.Sleep(minimumRtrSolicitationInterval * 2) + if err := s.removeNICLocked(nicID); err != nil { + t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) + } + s.mu.Unlock() +} + func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. @@ -44,7 +301,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 926df4d7b..1b5da6017 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -24,6 +24,8 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { + _ noCopy + // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry @@ -82,7 +84,32 @@ type PacketBuffer struct { // VectorisedView but does not deep copy the underlying bytes. // // Clone also does not deep copy any of its other fields. -func (pk PacketBuffer) Clone() PacketBuffer { - pk.Data = pk.Data.Clone(nil) - return pk +// +// FIXME(b/153685824): Data gets copied but not other header references. +func (pk *PacketBuffer) Clone() *PacketBuffer { + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + Header: pk.Header, + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + } } + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index db89234e8..5cbc946b6 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,12 +67,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +100,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +118,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +150,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -168,6 +168,11 @@ type TransportProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does + // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() < + // MinimumPacketSize() + Parse(pkt *PacketBuffer) (ok bool) } // TransportDispatcher contains the methods used by the network stack to deliver @@ -180,7 +185,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +194,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -242,7 +247,7 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have already // been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and @@ -251,7 +256,7 @@ type NetworkEndpoint interface { // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. - WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -266,7 +271,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -313,6 +318,14 @@ type NetworkProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It + // returns: + // - The encapsulated protocol, if present. + // - Whether there is an encapsulated transport protocol payload (e.g. ARP + // does not encapsulate anything). + // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader. + Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) } // NetworkDispatcher contains the methods used by the network stack to deliver @@ -327,7 +340,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -389,7 +402,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the // given route. pkts must not be zero length. It takes ownership of pkts and @@ -431,7 +444,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 3d0e5cc6e..d65f8049e 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -113,6 +113,8 @@ func (r *Route) GSOMaxSize() uint32 { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). +// +// The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if !r.IsResolutionRequired() { // Nothing to do if there is no cache (which does the resolution on cache miss) or @@ -148,12 +150,14 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. +// +// The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } @@ -199,7 +203,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0ab4c3e19..a2190341c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -52,7 +52,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -424,12 +424,8 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool - // tablesMu protects iptables. - tablesMu sync.RWMutex - - // tables are the iptables packet filtering and manipulation rules. The are - // protected by tablesMu.` - tables IPTables + // tables are the iptables packet filtering and manipulation rules. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -676,6 +672,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, @@ -778,7 +775,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1020,6 +1017,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() + return s.removeNICLocked(id) +} + +// removeNICLocked removes NIC and all related routes from the network stack. +// +// s.mu must be locked. +func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { nic, ok := s.nics[id] if !ok { return tcpip.ErrUnknownNICID @@ -1400,25 +1404,25 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. -func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() s.cleanupEndpoints[ep] = struct{}{} - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CompleteTransportEndpointCleanup removes the endpoint from the cleanup @@ -1741,18 +1745,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() IPTables { - s.tablesMu.RLock() - t := s.tables - s.tablesMu.RUnlock() - return t -} - -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt IPTables) { - s.tablesMu.Lock() - s.tables = ipt - s.tablesMu.Unlock() +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1a2cf007c..ffef9bc2c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -52,6 +52,10 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -90,30 +94,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fakeNetHeaderLen) - // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) + f.dispatcher.DeliverTransportControlPacket( + tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + fakeNetNumber, + tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -132,24 +134,19 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) + pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] + pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] + pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + f.HandlePacket(r, pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -163,7 +160,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -205,7 +202,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { } func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { @@ -247,6 +244,17 @@ func (*fakeNetworkProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(fakeNetHeaderLen) + return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true +} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -292,8 +300,8 @@ func TestNetworkReceive(t *testing.T) { buf := buffer.NewView(30) // Make sure packet with wrong address is not delivered. - buf[0] = 3 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 3 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -304,8 +312,8 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to first endpoint. - buf[0] = 1 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 1 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -316,8 +324,8 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to second endpoint. - buf[0] = 2 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = 2 + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -328,7 +336,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -340,7 +348,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -362,7 +370,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -420,7 +428,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -982,7 +990,7 @@ func TestAddressRemoval(t *testing.T) { buf := buffer.NewView(30) // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -1032,7 +1040,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) @@ -1114,7 +1122,7 @@ func TestEndpointExpiration(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte if promiscuous { if err := s.SetPromiscuousMode(nicID, true); err != nil { @@ -1277,7 +1285,7 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. @@ -1658,7 +1666,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -1766,7 +1774,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -2263,7 +2271,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2344,8 +2352,8 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) - buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + buf[dstAddrOffset] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9a33ed375..118b449d5 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,7 +15,6 @@ package stack import ( - "container/heap" "fmt" "math/rand" @@ -23,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" ) type protocolIDs struct { @@ -43,14 +43,14 @@ type transportEndpoints struct { // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() epsByNIC, ok := eps.endpoints[id] if !ok { return } - if !epsByNIC.unregisterEndpoint(bindToDevice, ep) { + if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) { return } delete(eps.endpoints, id) @@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] @@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -204,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() @@ -214,23 +214,22 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t demux: d, netProto: netProto, transProto: transProto, - reuse: reusePort, } epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, reusePort) + return multiPortEp.singleRegisterEndpoint(t, flags) } // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. -func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { return false } - if multiPortEp.unregisterEndpoint(t) { + if multiPortEp.unregisterEndpoint(t, flags) { delete(epsByNIC.endpoints, bindToDevice) } return len(epsByNIC.endpoints) == 0 @@ -251,7 +250,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -279,10 +278,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) return err } } @@ -290,35 +289,6 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } -type transportEndpointHeap []TransportEndpoint - -var _ heap.Interface = (*transportEndpointHeap)(nil) - -func (h *transportEndpointHeap) Len() int { - return len(*h) -} - -func (h *transportEndpointHeap) Less(i, j int) bool { - return (*h)[i].UniqueID() < (*h)[j].UniqueID() -} - -func (h *transportEndpointHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *transportEndpointHeap) Push(x interface{}) { - *h = append(*h, x.(TransportEndpoint)) -} - -func (h *transportEndpointHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - old[n-1] = nil - *h = old[:n-1] - return x -} - // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -334,9 +304,10 @@ type multiPortEndpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - endpoints transportEndpointHeap - // reuse indicates if more than one endpoint is allowed. - reuse bool + // endpoints stores the transport endpoints in the order in which they + // were bound. This is required for UDP SO_REUSEADDR. + endpoints []TransportEndpoint + flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -362,6 +333,10 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[0] } + if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent { + return mpep.endpoints[len(mpep.endpoints)-1] + } + payload := []byte{ byte(id.LocalPort), byte(id.LocalPort >> 8), @@ -379,7 +354,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs @@ -401,40 +376,47 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() + bits := flags.Bits() + if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if !ep.reuse || !reusePort { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { return tcpip.ErrPortInUse } } - heap.Push(&ep.endpoints, t) + ep.endpoints = append(ep.endpoints, t) + ep.flags.AddRef(bits) return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. -func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { +func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() defer ep.mu.Unlock() for i, endpoint := range ep.endpoints { if endpoint == t { - heap.Remove(&ep.endpoints, i) + copy(ep.endpoints[i:], ep.endpoints[i+1:]) + ep.endpoints[len(ep.endpoints)-1] = nil + ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] + + ep.flags.DropRef(flags.Bits()) break } } return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { - // TODO(eyalsoha): Why? - reusePort = false + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] @@ -454,15 +436,20 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.endpoints[id] = epsByNIC } - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep, bindToDevice) + eps.unregisterEndpoint(id, ep, flags, bindToDevice) } } } @@ -470,7 +457,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -520,7 +507,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -544,7 +531,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 2474a7db3..73dada928 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -127,7 +128,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -165,7 +166,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -195,7 +196,7 @@ func TestTransportDemuxerRegister(t *testing.T) { if !ok { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..7e8b84867 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -83,12 +84,13 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) + hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) + hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -153,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -198,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, /* reuse */ - 0, /* bindtoDevice */ + ports.Flags{}, + 0, /* bindtoDevice */ ); err != nil { return err } @@ -215,7 +217,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -232,7 +234,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -289,7 +291,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -324,6 +326,17 @@ func (*fakeTransportProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeTransportProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) + if !ok { + return false + } + pkt.TransportHeader = hdr + pkt.Data.TrimFront(fakeTransHeaderLen) + return true +} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } @@ -369,7 +382,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -380,7 +393,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -391,7 +404,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -446,7 +459,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -457,7 +470,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -468,7 +481,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -623,7 +636,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: req.ToVectorisedView(), }) @@ -642,11 +655,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.NetworkHeader[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.NetworkHeader[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9ce625c17..7e5c79776 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..8ce294002 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,7 +111,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -140,11 +141,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -450,7 +446,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), @@ -481,7 +477,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: dataVV, TransportHeader: buffer.View(icmpv6), @@ -511,6 +507,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicID := addr.NIC localPort := uint16(0) switch e.state { + case stateInitial: case stateBound, stateConnected: localPort = e.ID.LocalPort if e.BindNICID == 0 { @@ -611,14 +608,14 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) switch err { case nil: return true, nil @@ -743,7 +740,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -805,7 +802,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3c47692b2..74ef6541e 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -124,6 +124,16 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. + // + // Right now, the Parse() method is tied to enabled protocols passed into + // stack.New. This works for UDP and TCP, but we handle ICMP traffic even + // when netstack users don't pass ICMP as a supported protocol. + return false +} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 23158173d..baf08eda6 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -132,11 +132,6 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (stack.IPTables, error) { - return ep.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() @@ -298,7 +293,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index eee754a5a..a406d815e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !e.associated { @@ -348,7 +343,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{ + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { return 0, nil, err @@ -357,7 +352,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), Owner: e.owner, @@ -584,7 +579,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -632,8 +627,9 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { }, } - networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) - combinedVV := networkHeader.ToVectorisedView() + headers := append(buffer.View(nil), pkt.NetworkHeader...) + headers = append(headers, pkt.TransportHeader...) + combinedVV := headers.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV packet.timestampNS = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f38eb6833..e26f01fae 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -86,10 +86,6 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], - # FIXME(b/68809571) - tags = [ - "flaky", - ], deps = [ ":tcp", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e6a23c978..ad197e8db 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -238,13 +239,14 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.mu.Lock() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil { n.mu.Unlock() n.Close() return nil, err } n.isRegistered = true + n.registeredReusePort = n.reusePort return n, nil } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index e4a06c9e1..7da93dcc4 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -833,13 +833,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), Data: data, Hash: tf.txHash, Owner: owner, } - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { tf.ttl = r.DefaultTTL() diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 6062ca916..047704c80 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -186,7 +186,7 @@ func (d *dispatcher) wait() { } } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) if !s.parse() { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b5ba972f1..6e4d607da 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,7 +63,8 @@ const ( StateClosing ) -// connected is the set of states where an endpoint is connected to a peer. +// connected returns true when s is one of the states representing an +// endpoint connected to a peer. func (s EndpointState) connected() bool { switch s { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: @@ -73,6 +74,40 @@ func (s EndpointState) connected() bool { } } +// connecting returns true when s is one of the states representing a +// connection in progress, but not yet fully established. +func (s EndpointState) connecting() bool { + switch s { + case StateConnecting, StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// handshake returns true when s is one of the states representing an endpoint +// in the middle of a TCP handshake. +func (s EndpointState) handshake() bool { + switch s { + case StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// closed returns true when s is one of the states an endpoint transitions to +// when closed or when it encounters an error. This is distinct from a newly +// initialized endpoint that was never connected. +func (s EndpointState) closed() bool { + switch s { + case StateClose, StateError: + return true + default: + return false + } +} + // String implements fmt.Stringer.String. func (s EndpointState) String() string { switch s { @@ -430,6 +465,10 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // registeredReusePort is set if the current endpoint registration was + // done with SO_REUSEPORT enabled. + registeredReusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -986,8 +1025,9 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) e.isRegistered = false + e.registeredReusePort = false } e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) @@ -1051,8 +1091,9 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) e.isRegistered = false + e.registeredReusePort = false } if e.isPortReserved { @@ -1172,11 +1213,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -2058,10 +2094,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice) if err != nil { return err } + e.registeredReusePort = e.reusePort } else { // The endpoint doesn't have a local port yet, so try to get // one. Make sure that it isn't one that will result in the same @@ -2093,12 +2130,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) { case nil: // Port picking successful. Save the details of // the selected port. e.ID = id e.boundBindToDevice = e.bindToDevice + e.registeredReusePort = e.reusePort return true, nil case tcpip.ErrPortInUse: return false, nil @@ -2296,12 +2334,13 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) + e.registeredReusePort = e.reusePort // The channel may be non-nil when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) @@ -2462,7 +2501,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -2481,7 +2520,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index fc43c11e2..cbb779666 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -49,11 +49,10 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.EndpointState() { - case StateInitial, StateBound: - // TODO(b/138137272): this enumeration duplicates - // EndpointState.connected. remove it. - case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + epState := e.EndpointState() + switch { + case epState == StateInitial || epState == StateBound: + case epState.connected() || epState.handshake(): if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) @@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() { break } fallthrough - case StateListen, StateConnecting: + case epState == StateListen || epState == StateConnecting: e.drainSegmentLocked() - if e.EndpointState() != StateClose && e.EndpointState() != StateError { + // Refresh epState, since drainSegmentLocked may have changed it. + epState = e.EndpointState() + if !epState.closed() { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } - break } - case StateError, StateClose: + case epState.closed(): for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) @@ -148,23 +148,23 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state EndpointState) { +func (e *endpoint) loadState(epState EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. // For restore purposes we treat TimeWait like a connected endpoint. - if state.connected() || state == StateTimeWait { + if epState.connected() || epState == StateTimeWait { connectedLoading.Add(1) } - switch state { - case StateListen: + switch { + case epState == StateListen: listenLoading.Add(1) - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): connectingLoading.Add(1) } // Directly update the state here rather than using e.setEndpointState // as the endpoint is still being loaded and the stack reference is not // yet initialized. - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32((*uint32)(&e.state), uint32(epState)) } // afterLoad is invoked by stateify. @@ -183,8 +183,8 @@ func (e *endpoint) afterLoad() { func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) - state := e.origEndpointState - switch state { + epState := e.origEndpointState + switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -208,8 +208,8 @@ func (e *endpoint) Resume(s *stack.Stack) { } } - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + switch { + case epState.connected(): bind() if len(e.connectingAddress) == 0 { e.connectingAddress = e.ID.RemoteAddress @@ -232,13 +232,13 @@ func (e *endpoint) Resume(s *stack.Stack) { closed := e.closed e.mu.Unlock() e.notifyProtocolGoroutine(notifyTickleWorker) - if state == StateFinWait2 && closed { + if epState == StateFinWait2 && closed { // If the endpoint has been closed then make sure we notify so // that the FIN_WAIT2 timer is started after a restore. e.notifyProtocolGoroutine(notifyClose) } connectedLoading.Done() - case StateListen: + case epState == StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -255,7 +255,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Done() tcpip.AsyncLoading.Done() }() - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -267,7 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case StateBound: + case epState == StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -276,7 +276,7 @@ func (e *endpoint) Resume(s *stack.Stack) { bind() tcpip.AsyncLoading.Done() }() - case StateClose: + case epState == StateClose: if e.isPortReserved { tcpip.AsyncLoading.Add(1) go func() { @@ -291,12 +291,11 @@ func (e *endpoint) Resume(s *stack.Stack) { e.state = StateClose e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) - case StateError: + case epState == StateError: e.state = StateError e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } - } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 704d01c64..070b634b4 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a2a7ddeb..73b8a6782 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,6 +21,7 @@ package tcp import ( + "fmt" "runtime" "strings" "time" @@ -206,7 +207,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { p.dispatcher.queuePacket(r, ep, id, pkt) } @@ -217,7 +218,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() @@ -490,6 +491,26 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { return &p.synRcvdCount } +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(offset) + if !ok { + panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) + } + } + + pkt.TransportHeader = hdr + pkt.Data.TrimFront(len(hdr)) + return true +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 074edded6..0280892a8 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -35,6 +35,7 @@ type segment struct { id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -60,13 +61,14 @@ type segment struct { xmitCount uint32 } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) + s.hdr = header.TCP(pkt.TransportHeader) s.rcvdTime = time.Now() return s } @@ -146,12 +148,6 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) - // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -162,16 +158,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(s.hdr.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(s.hdr) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -180,22 +172,19 @@ func (s *segment) parse() bool { if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { s.csumValid = true verifyChecksum = false - s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) - s.data.TrimFront(offset) + s.csum = s.hdr.Checksum() + xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(s.hdr.AckNumber()) + s.flags = s.hdr.Flags() + s.window = seqnum.Size(s.hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 06dc9b7d7..acacb42e4 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -618,6 +618,20 @@ func (s *sender) splitSeg(seg *segment, size int) { nSeg.data.TrimFront(size) nSeg.sequenceNumber.UpdateForward(seqnum.Size(size)) s.writeList.InsertAfter(seg, nSeg) + + // The segment being split does not carry PUSH flag because it is + // followed by the newly split segment. + // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered + // segment (i.e., when there is no more queued data to be sent). + // Linux removes PSH flag only when the segment is being split over MSS + // and retains it when we are splitting the segment over lack of sender + // window space. + // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() + if seg.data.Size() > s.maxPayloadSize { + seg.flags ^= header.TCPFlagPsh + } + seg.data.CapLength(size) } @@ -739,7 +753,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(seg.sequenceNumber.Size(end)) + available := int(s.sndNxt.Size(end)) if available > limit { available = limit } @@ -782,8 +796,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // sent all at once. return false } - if atomic.LoadUint32(&s.ep.cork) != 0 { - // Hold back the segment until full. + // With TCP_CORK, hold back until minimum of the available + // send space and MSS. + // TODO(gvisor.dev/issue/2833): Drain the held segments after a + // timeout. + if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { return false } } @@ -816,6 +833,25 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se panic("Netstack queues FIN segments without data.") } + segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + // If the entire segment cannot be accomodated in the receiver + // advertized window, skip splitting and sending of the segment. + // ref: net/ipv4/tcp_output.c::tcp_snd_wnd_test() + // + // Linux checks this for all segment transmits not triggered + // by a probe timer. On this condition, it defers the segment + // split and transmit to a short probe timer. + // ref: include/net/tcp.h::tcp_check_probe_timer() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() + // + // Instead of defining a new transmit timer, we attempt to split the + // segment right here if there are no pending segments. + // If there are pending segments, segment transmits are deferred + // to the retransmit timer handler. + if s.sndUna != s.sndNxt && !segEnd.LessThan(end) { + return false + } + if !seg.sequenceNumber.LessThan(end) { return false } @@ -824,9 +860,17 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if available == 0 { return false } + + // The segment size limit is computed as a function of sender congestion + // window and MSS. When sender congestion window is > 1, this limit can + // be larger than MSS. Ensure that the currently available send space + // is not greater than minimum of this limit and MSS. if available > limit { available = limit } + if available > s.maxPayloadSize { + available = s.maxPayloadSize + } if seg.data.Size() > available { s.splitSeg(seg, available) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 7b1d72cf4..9721f6caf 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -316,7 +316,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -372,7 +372,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: s, }) } @@ -380,7 +380,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) { // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegment(payload, h), }) } @@ -389,7 +389,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) } @@ -564,7 +564,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 647b2067a..df5efbf6a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,6 +15,7 @@ package udp import ( + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -102,7 +103,7 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID multicastLoop bool - reusePort bool + portFlags ports.Flags bindToDevice tcpip.NICID broadcast bool @@ -213,7 +214,7 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -247,11 +248,6 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -430,24 +426,33 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } var route *stack.Route + var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) var dstPort uint16 if to == nil { route = &e.route dstPort = e.dstPort - - if route.IsResolutionRequired() { - // Promote lock to exclusive if using a shared route, given that it may need to - // change in Route.Resolve() call below. + resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) { + // Promote lock to exclusive if using a shared route, given that it may + // need to change in Route.Resolve() call below. e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() // Recheck state after lock was re-acquired. if e.state != StateConnected { - return 0, nil, tcpip.ErrInvalidEndpointState + err = tcpip.ErrInvalidEndpointState + } + if err == nil && route.IsResolutionRequired() { + ch, err = route.Resolve(waker) } + + e.mu.Unlock() + e.mu.RLock() + + // Recheck state after lock was re-acquired. + if e.state != StateConnected { + err = tcpip.ErrInvalidEndpointState + } + return } } else { // Reject destination address if it goes through a different @@ -478,10 +483,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c route = &r dstPort = dst.Port + resolve = route.Resolve } if route.IsResolutionRequired() { - if ch, err := route.Resolve(nil); err != nil { + if ch, err := resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { return 0, ch, tcpip.ErrNoLinkAddress } @@ -552,10 +558,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Unlock() case tcpip.ReuseAddressOption: + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() case tcpip.ReusePortOption: e.mu.Lock() - e.reusePort = v + e.portFlags.LoadBalanced = v e.mu.Unlock() case tcpip.V6OnlyOption: @@ -789,11 +798,15 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil case tcpip.ReuseAddressOption: - return false, nil + e.mu.RLock() + v := e.portFlags.MostRecent + e.mu.RUnlock() + + return v, nil case tcpip.ReusePortOption: e.mu.RLock() - v := e.reusePort + v := e.portFlags.LoadBalanced e.mu.RUnlock() return v, nil @@ -921,7 +934,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: ProtocolNumber, + TTL: ttl, + TOS: tos, + }, &stack.PacketBuffer{ Header: hdr, Data: data, TransportHeader: buffer.View(udp), @@ -958,6 +975,11 @@ func (e *endpoint) Disconnect() *tcpip.Error { id stack.TransportEndpointID btd tcpip.NICID ) + + // We change this value below and we need the old value to unregister + // the endpoint. + boundPortFlags := e.boundPortFlags + // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error @@ -970,16 +992,17 @@ func (e *endpoint) Disconnect() *tcpip.Error { return err } e.state = StateBound + boundPortFlags = e.boundPortFlags } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice) e.boundPortFlags = ports.Flags{} } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() @@ -1051,6 +1074,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } + oldPortFlags := e.boundPortFlags + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err @@ -1058,7 +1083,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) } e.ID = id @@ -1122,20 +1147,15 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - flags := ports.Flags{ - LoadBalanced: e.reusePort, - // FIXME(b/129164367): Support SO_REUSEADDR. - MostRecent: false, - } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice) if err != nil { return id, e.bindToDevice, err } - e.boundPortFlags = flags id.LocalPort = port } + e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) e.boundPortFlags = ports.Flags{} @@ -1269,18 +1289,16 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.TransportHeader) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - pkt.Data.TrimFront(header.UDPMinimumSize) - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() @@ -1336,7 +1354,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { e.mu.RLock() defer e.mu.RUnlock() diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a674ceb68..c67e0ba95 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, @@ -61,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - pkt stack.PacketBuffer + pkt *stack.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that @@ -73,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() return nil, err } @@ -82,6 +82,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort ep.RegisterNICID = r.route.NICID() + ep.boundPortFlags = ep.portFlags ep.state = StateConnected diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..4218e7d03 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,15 +66,9 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { - // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + hdr := header.UDP(pkt.TransportHeader) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -121,7 +115,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -130,9 +124,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // For example, a raw or packet socket may use what UDP // considers an unreachable destination. Thus we deep copy pkt // to prevent multiple ownership and SR errors. - newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) - payload := newNetHeader.ToVectorisedView() - payload.Append(pkt.Data.ToView().ToVectorisedView()) + newHeader := append(buffer.View(nil), pkt.NetworkHeader...) + newHeader = append(newHeader, pkt.TransportHeader...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -140,9 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: payload, + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, }) case header.IPv6AddressSize: @@ -164,11 +160,11 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) payload.Append(pkt.Data) payload.CapLength(payloadLen) @@ -177,9 +173,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: payload, + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, }) } return true @@ -201,6 +198,18 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Packet is too small + return false + } + pkt.TransportHeader = h + pkt.Data.TrimFront(header.UDPMinimumSize) + return true +} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 8acaa607a..313a3f117 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -440,10 +440,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), }) } @@ -487,10 +485,8 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), }) } @@ -1720,6 +1716,58 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { } } +// TestShortHeader verifies that when a packet with a too-short UDP header is +// received, the malformed received global stat gets incremented. +func TestShortHeader(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + + // Allocate a buffer for an IPv6 and too-short UDP header. + const udpSize = header.UDPMinimumSize - 1 + buf := buffer.NewView(header.IPv6MinimumSize + udpSize) + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, + }) + + // Initialize the UDP header. + udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, + Length: header.UDPMinimumSize, + }) + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + // Copy all but the last byte of the UDP header into the packet. + copy(buf[header.IPv6MinimumSize:], udpHdr) + + // Inject packet. + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + } +} + // TestShutdownRead verifies endpoint read shutdown and error // stats increment on packet receive. func TestShutdownRead(t *testing.T) { |