diff options
34 files changed, 462 insertions, 141 deletions
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index d37d624fe..6e3ee2e50 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -167,6 +167,12 @@ const ( // TCPMinimumSize is the minimum size of a valid TCP packet. TCPMinimumSize = 20 + // TCPOptionsMaximumSize is the maximum size of TCP options. + TCPOptionsMaximumSize = 40 + + // TCPHeaderMaximumSize is the maximum header size of a TCP packet. + TCPHeaderMaximumSize = TCPMinimumSize + TCPOptionsMaximumSize + // TCPProtocolNumber is TCP's transport protocol number. TCPProtocolNumber tcpip.TransportProtocolNumber = 6 ) @@ -291,6 +297,11 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 b.SetChecksum(^checksum) } +// TCPChecksumOffset returns offset of the checksum field. +func TCPChecksumOffset() uint16 { + return tcpChecksum +} + // ParseSynOptions parses the options received in a SYN segment and returns the // relevant ones. opts should point to the option part of the TCP Header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 25cffa787..8c0d11288 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -109,7 +109,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { p := PacketInfo{ Header: hdr.View(), Proto: protocol, diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index bcf9c023e..50ce91a4e 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -6,6 +6,7 @@ go_library( name = "fdbased", srcs = [ "endpoint.go", + "endpoint_unsafe.go", "mmap.go", "mmap_amd64_unsafe.go", ], diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index d726551b0..20e34c5ee 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -111,6 +111,10 @@ type endpoint struct { // ringOffset is the current offset into the ring buffer where the next // inbound packet will be placed by the kernel. ringOffset int + + // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is + // disabled. + gsoMaxSize uint32 } // Options specify the details about the fd-based endpoint to be created. @@ -123,6 +127,7 @@ type Options struct { Address tcpip.LinkAddress SaveRestore bool DisconnectOk bool + GSOMaxSize uint32 PacketDispatchMode PacketDispatchMode } @@ -165,6 +170,10 @@ func New(opts *Options) tcpip.LinkEndpointID { packetDispatchMode: opts.PacketDispatchMode, } + if opts.GSOMaxSize != 0 && isSocketFD(opts.FD) { + e.caps |= stack.CapabilityGSO + e.gsoMaxSize = opts.GSOMaxSize + } if isSocketFD(opts.FD) && e.packetDispatchMode == PacketMMap { if err := e.setupPacketRXRing(); err != nil { // TODO: replace panic with an error return. @@ -185,17 +194,22 @@ func New(opts *Options) tcpip.LinkEndpointID { } e.views = make([][]buffer.View, msgsPerRecv) - for i, _ := range e.views { + for i := range e.views { e.views[i] = make([]buffer.View, len(BufConfig)) } e.iovecs = make([][]syscall.Iovec, msgsPerRecv) - for i, _ := range e.iovecs { - e.iovecs[i] = make([]syscall.Iovec, len(BufConfig)) + iovLen := len(BufConfig) + if e.Capabilities()&stack.CapabilityGSO != 0 { + // virtioNetHdr is prepended before each packet. + iovLen++ + } + for i := range e.iovecs { + e.iovecs[i] = make([]syscall.Iovec, iovLen) } e.msgHdrs = make([]rawfile.MMsgHdr, msgsPerRecv) - for i, _ := range e.msgHdrs { + for i := range e.msgHdrs { e.msgHdrs[i].Msg.Iov = &e.iovecs[i][0] - e.msgHdrs[i].Msg.Iovlen = uint64(len(BufConfig)) + e.msgHdrs[i].Msg.Iovlen = uint64(iovLen) } return stack.RegisterLinkEndpoint(e) @@ -246,9 +260,27 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } +// virtioNetHdr is declared in linux/virtio_net.h. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +// These constants are declared in linux/virtio_net.h. +const ( + _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1 + + _VIRTIO_NET_HDR_GSO_TCPV4 = 1 + _VIRTIO_NET_HDR_GSO_TCPV6 = 4 +) + // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) @@ -266,11 +298,37 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b eth.Encode(ethHdr) } + if e.Capabilities()&stack.CapabilityGSO != 0 { + vnetHdr := virtioNetHdr{} + vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) + if gso != nil { + vnetHdr.hdrLen = uint16(hdr.UsedLength()) + if gso.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen + vnetHdr.csumOffset = gso.CsumOffset + } + if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS { + switch gso.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + } + vnetHdr.gsoSize = gso.MSS + } + } + + return rawfile.NonBlockingWrite3(e.fd, vnetHdrBuf, hdr.View(), payload.ToView()) + } + if payload.Size() == 0 { return rawfile.NonBlockingWrite(e.fd, hdr.View()) } - return rawfile.NonBlockingWrite2(e.fd, hdr.View(), payload.ToView()) + return rawfile.NonBlockingWrite3(e.fd, hdr.View(), payload.ToView(), nil) } // WriteRawPacket writes a raw packet directly to the file descriptor. @@ -292,13 +350,25 @@ func (e *endpoint) capViews(k, n int, buffers []int) int { func (e *endpoint) allocateViews(bufConfig []int) { for k := 0; k < len(e.views); k++ { + var vnetHdr [virtioNetHdrSize]byte + vnetHdrOff := 0 + if e.Capabilities()&stack.CapabilityGSO != 0 { + // The kernel adds virtioNetHdr before each packet, but + // we don't use it, so so we allocate a buffer for it, + // add it in iovecs but don't add it in a view. + e.iovecs[k][0] = syscall.Iovec{ + Base: &vnetHdr[0], + Len: uint64(virtioNetHdrSize), + } + vnetHdrOff++ + } for i := 0; i < len(bufConfig); i++ { if e.views[k][i] != nil { break } b := buffer.NewView(bufConfig[i]) e.views[k][i] = b - e.iovecs[k][i] = syscall.Iovec{ + e.iovecs[k][i+vnetHdrOff] = syscall.Iovec{ Base: &b[0], Len: uint64(len(b)), } @@ -314,7 +384,11 @@ func (e *endpoint) dispatch() (bool, *tcpip.Error) { if err != nil { return false, err } - + if e.Capabilities()&stack.CapabilityGSO != 0 { + // Skip virtioNetHdr which is added before each packet, it + // isn't used and it isn't in a view. + n -= virtioNetHdrSize + } if n <= e.hdrSize { return false, nil } @@ -366,8 +440,11 @@ func (e *endpoint) recvMMsgDispatch() (bool, *tcpip.Error) { } // Process each of received packets. for k := 0; k < nMsgs; k++ { - n := e.msgHdrs[k].Len - if n <= uint32(e.hdrSize) { + n := int(e.msgHdrs[k].Len) + if e.Capabilities()&stack.CapabilityGSO != 0 { + n -= virtioNetHdrSize + } + if n <= e.hdrSize { return false, nil } @@ -425,6 +502,11 @@ func (e *endpoint) dispatchLoop() *tcpip.Error { } } +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + return e.gsoMaxSize +} + // InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes // to the FD, but does not read from it. All reads come from injected packets. type InjectableEndpoint struct { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 14abacdf2..ecc5b73f3 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -24,6 +24,7 @@ import ( "syscall" "testing" "time" + "unsafe" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" @@ -33,10 +34,12 @@ import ( ) const ( - mtu = 1500 - laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") - raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") - proto = 10 + mtu = 1500 + laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") + raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") + proto = 10 + csumOffset = 48 + gsoMSS = 500 ) type packetInfo struct { @@ -130,67 +133,108 @@ func TestAddress(t *testing.T) { } } -func TestWritePacket(t *testing.T) { - lengths := []int{0, 100, 1000} - eths := []bool{true, false} +func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { + c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) + defer c.cleanup() - for _, eth := range eths { - for _, plen := range lengths { - t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { - c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) - defer c.cleanup() + r := &stack.Route{ + RemoteLinkAddress: raddr, + } - r := &stack.Route{ - RemoteLinkAddress: raddr, - } + // Build header. + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) + b := hdr.Prepend(100) + for i := range b { + b[i] = uint8(rand.Intn(256)) + } - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) - } + // Build payload and write. + payload := make(buffer.View, plen) + for i := range payload { + payload[i] = uint8(rand.Intn(256)) + } + want := append(hdr.View(), payload...) + var gso *stack.GSO + if gsoMaxSize != 0 { + gso = &stack.GSO{ + Type: stack.GSOTCPv6, + NeedsCsum: true, + CsumOffset: csumOffset, + MSS: gsoMSS, + MaxSize: gsoMaxSize, + } + } + if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) - } - want := append(hdr.View(), payload...) - if err := c.ep.WritePacket(r, hdr, payload.ToVectorisedView(), proto); err != nil { - t.Fatalf("WritePacket failed: %v", err) - } + // Read from fd, then compare with what we wrote. + b = make([]byte, mtu) + n, err := syscall.Read(c.fds[0], b) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + b = b[:n] + if gsoMaxSize != 0 { + vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0])) + if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { + t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) + } + csumStart := header.EthernetMinimumSize + gso.L3HdrLen + if vnetHdr.csumStart != csumStart { + t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) + } + if vnetHdr.csumOffset != csumOffset { + t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) + } + gsoType := uint8(0) + if int(gso.MSS) < plen { + gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + } + if vnetHdr.gsoType != gsoType { + t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType) + } + b = b[virtioNetHdrSize:] + } + if eth { + h := header.Ethernet(b) + b = b[header.EthernetMinimumSize:] - // Read from fd, then compare with what we wrote. - b = make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - b = b[:n] - if eth { - h := header.Ethernet(b) - b = b[header.EthernetMinimumSize:] + if a := h.SourceAddress(); a != laddr { + t.Fatalf("SourceAddress() = %v, want %v", a, laddr) + } - if a := h.SourceAddress(); a != laddr { - t.Fatalf("SourceAddress() = %v, want %v", a, laddr) - } + if a := h.DestinationAddress(); a != raddr { + t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) + } - if a := h.DestinationAddress(); a != raddr { - t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) - } + if et := h.Type(); et != proto { + t.Fatalf("Type() = %v, want %v", et, proto) + } + } + if len(b) != len(want) { + t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) + } + if !bytes.Equal(b, want) { + t.Fatalf("Read returned %x, want %x", b, want) + } +} - if et := h.Type(); et != proto { - t.Fatalf("Type() = %v, want %v", et, proto) - } - } - if len(b) != len(want) { - t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) - } - if !bytes.Equal(b, want) { - t.Fatalf("Read returned %x, want %x", b, want) - } - }) +func TestWritePacket(t *testing.T) { + lengths := []int{0, 100, 1000} + eths := []bool{true, false} + gsos := []uint32{0, 32768} + + for _, eth := range eths { + for _, plen := range lengths { + for _, gso := range gsos { + t.Run( + fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), + func(t *testing.T) { + testWritePacket(t, plen, eth, gso) + }, + ) + } } } } @@ -210,7 +254,7 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go new file mode 100644 index 000000000..36e7fe5a9 --- /dev/null +++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go @@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package fdbased + +import ( + "reflect" + "unsafe" +) + +const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{})) + +func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + sh.Data = uintptr(unsafe.Pointer(hdr)) + sh.Len = virtioNetHdrSize + sh.Cap = virtioNetHdrSize + return +} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index fa54872da..d58c0f885 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -72,7 +72,7 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 29073afae..99edc232d 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,9 +87,9 @@ func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, hdr, payload, protocol) + return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol) } return tcpip.ErrNoRoute } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index d1d2875cc..7d25effad 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, hdr, + endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber) buf := make([]byte, 6500) @@ -68,7 +68,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, hdr, + endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 5d36ebe57..fe2779125 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -65,9 +65,9 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { return nil } -// NonBlockingWrite2 writes up to two byte slices to a file descriptor in a +// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a // single syscall. It fails if partial data is written. -func NonBlockingWrite2(fd int, b1, b2 []byte) *tcpip.Error { +func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { // If the is no second buffer, issue a regular write. if len(b2) == 0 { return NonBlockingWrite(fd, b1) @@ -75,7 +75,7 @@ func NonBlockingWrite2(fd int, b1, b2 []byte) *tcpip.Error { // We have two buffers. Build the iovec that represents them and issue // a writev syscall. - iovec := [...]syscall.Iovec{ + iovec := [3]syscall.Iovec{ { Base: &b1[0], Len: uint64(len(b1)), @@ -85,8 +85,15 @@ func NonBlockingWrite2(fd int, b1, b2 []byte) *tcpip.Error { Len: uint64(len(b2)), }, } + iovecLen := uintptr(2) - _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) + if len(b3) > 0 { + iovecLen++ + iovec[2].Base = &b3[0] + iovec[2].Len = uint64(len(b3)) + } + + _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { return TranslateErrno(e) } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 27d7eb3b9..6e6aa5a13 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -184,7 +184,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { // Add the ethernet header here. eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 4b8061b13..1f44e224c 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -272,7 +272,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -341,7 +341,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -395,7 +395,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -410,7 +410,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -435,7 +435,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -455,7 +455,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -493,7 +493,7 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -509,7 +509,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) + err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -534,7 +534,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -547,7 +547,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, hdr, uu, header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -555,7 +555,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 4768321d3..462a6e3a3 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -185,10 +185,18 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.lower.LinkAddress() } +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + if gso, ok := e.lower.(stack.GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and forwards // the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, hdr.View()) } @@ -229,7 +237,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b panic(err) } } - return e.lower.WritePacket(r, hdr, payload, protocol) + return e.lower.WritePacket(r, gso, hdr, payload, protocol) } func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 39217e49c..bd9f9845b 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -100,12 +100,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, hdr, payload, protocol) + err := e.lower.WritePacket(r, gso, hdr, payload, protocol) e.writeGate.Leave() return err } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 6c57e597a..a2df6be95 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { e.writeCount++ return nil } @@ -75,21 +75,21 @@ func TestWaitWrite(t *testing.T) { _, wep := New(stack.RegisterLinkEndpoint(ep)) // Write and check that it goes through. - wep.WritePacket(nil, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 5ab542f2c..975919e80 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { return tcpip.ErrNotSupported } @@ -103,7 +103,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) + e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -155,7 +155,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 7eb0e697d..d79eba4b0 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -145,7 +145,7 @@ func (*testObject) LinkAddress() tcpip.LinkAddress { // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -221,7 +221,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -450,7 +450,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 8f94246c9..a9650de03 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -76,7 +76,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V copy(pkt, h) pkt.SetType(header.ICMPv4EchoReply) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) - r.WritePacket(hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()) case header.ICMPv4EchoReply: if len(v) < header.ICMPv4EchoMinimumSize { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c2f9a1bcf..cbdca98a5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -99,8 +99,16 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize } +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payload.Size()) id := uint32(0) @@ -132,7 +140,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) } // HandlePacket is called by the link layer when new ipv4 packets arrive for diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index cfc05d9e1..36d98caef 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -118,7 +118,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V defer r.Release() r.LocalAddress = targetAddr pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - r.WritePacket(hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()) e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) @@ -143,7 +143,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - r.WritePacket(hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()) case header.ICMPv6EchoReply: if len(v) < header.ICMPv6EchoMinimumSize { @@ -202,7 +202,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. DstAddr: r.RemoteAddress, }) - return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index df3b64c98..9a743ea80 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -83,8 +83,16 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize } +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { length := uint16(hdr.UsedLength() + payload.Size()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -107,7 +115,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) } // HandlePacket is called by the link layer when new ipv6 packets arrive for diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 1d032ebf8..8b6c17a90 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -486,7 +486,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr vv.RemoveFirst() // TODO: use route.WritePacket. - if err := n.linkEP.WritePacket(&r, hdr, vv, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index cf4d52fe9..ff356ea22 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -161,7 +161,7 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. - WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error + WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -226,6 +226,7 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback + CapabilityGSO ) // LinkEndpoint is the interface implemented by data link layer protocols (e.g., @@ -258,7 +259,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. @@ -381,3 +382,41 @@ func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { return linkEndpoints[id] } + +// GSOType is the type of GSO segments. +// +// +stateify savable +type GSOType int + +// Types of gso segments. +const ( + GSONone GSOType = iota + GSOTCPv4 + GSOTCPv6 +) + +// GSO contains generic segmentation offload properties. +// +// +stateify savable +type GSO struct { + // Type is one of GSONone, GSOTCPv4, etc. + Type GSOType + // NeedsCsum is set if the checksum offload is enabled. + NeedsCsum bool + // CsumOffset is offset after that to place checksum. + CsumOffset uint16 + + // Mss is maximum segment size. + MSS uint16 + // L3Len is L3 (IP) header length. + L3HdrLen uint16 + + // MaxSize is maximum GSO packet size. + MaxSize uint32 +} + +// GSOEndpoint provides access to GSO properties. +type GSOEndpoint interface { + // GSOMaxSize returns the maximum GSO packet size. + GSOMaxSize() uint32 +} diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index ee860eafe..8ae562dcd 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -97,6 +97,14 @@ func (r *Route) Capabilities() LinkEndpointCapabilities { return r.ref.ep.Capabilities() } +// GSOMaxSize returns the maximum GSO packet size. +func (r *Route) GSOMaxSize() uint32 { + if gso, ok := r.ref.ep.(GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -144,8 +152,8 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { - err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl, r.loop) +func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index da8269999..b5375df3c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -112,7 +112,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.linkEP.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -134,7 +134,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable return nil } - return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber) + return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) } func (*fakeNetworkEndpoint) Close() {} @@ -281,7 +281,7 @@ func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.Vie defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { t.Errorf("WritePacket failed: %v", err) } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 279ab3c56..dfd31557a 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -74,7 +74,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) if err != nil { return 0, nil, err } - if err := f.route.WritePacket(hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { + if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d876005fe..182097b46 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -370,7 +370,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { if e.raw { hdr := buffer.NewPrependable(len(data) + int(r.MaxHeaderLength())) - return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) } if len(data) < header.ICMPv4EchoMinimumSize { @@ -395,7 +395,7 @@ func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) } func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { @@ -419,7 +419,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv6.SetChecksum(0) icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) - return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7a19737c7..a3894ed8f 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -214,6 +214,8 @@ func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, ir n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts) + n.initGSO() + // Register new endpoint so that packets are routed to it. if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { n.Close() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index c4353718e..056e0b09a 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -557,14 +557,14 @@ func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, a } options := makeSynOptions(opts) - err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options) + err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil) putOptions(options) return err } // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { optLen := len(opts) // Allocate a buffer for the TCP header. hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) @@ -586,12 +586,17 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise }) copy(tcp[header.TCPMinimumSize:], opts) + length := uint16(hdr.UsedLength() + data.Size()) + xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { - length := uint16(hdr.UsedLength() + data.Size()) - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + if gso != nil && gso.NeedsCsum { + // This is called CHECKSUM_PARTIAL in the Linux kernel. We + // calculate a checksum of the pseudo-header and save it in the + // TCP header, then the kernel calculate a checksum of the + // header and data and get the right sum of the TCP packet. + tcp.SetChecksum(xsum) + } else if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { xsum = header.ChecksumVV(data, xsum) - tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } @@ -600,7 +605,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise r.Stats().TCP.ResetsSent.Increment() } - return r.WritePacket(hdr, data, ProtocolNumber, ttl) + return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl) } // makeOptions makes an options slice. @@ -649,7 +654,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options) + err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso) putOptions(options) return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 5656890f6..0427af34f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "math" "sync" "sync/atomic" @@ -265,6 +266,8 @@ type endpoint struct { // The following are only used to assist the restore run to re-connect. bindAddress tcpip.Address connectingAddress tcpip.Address + + gso *stack.GSO } // StopWork halts packet processing. Only to be used in tests. @@ -1155,6 +1158,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr + e.initGSO() + // Connect in the restore phase does not perform handshake. Restore its // connection setting here. if !handshake { @@ -1698,3 +1703,25 @@ func (e *endpoint) completeState() stack.TCPEndpointState { } return s } + +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityGSO == 0 { + return + } + + gso := &stack.GSO{} + switch e.netProto { + case header.IPv4ProtocolNumber: + gso.Type = stack.GSOTCPv4 + gso.L3HdrLen = header.IPv4MinimumSize + case header.IPv6ProtocolNumber: + gso.Type = stack.GSOTCPv6 + gso.L3HdrLen = header.IPv6MinimumSize + default: + panic(fmt.Sprintf("Unknown netProto: %v", e.netProto)) + } + gso.NeedsCsum = true + gso.CsumOffset = header.TCPChecksumOffset() + gso.MaxSize = e.route.GSOMaxSize() + e.gso = gso +} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 8a42f8593..230668b5d 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -153,7 +153,7 @@ func replyWithReset(s *segment) { ack := s.sequenceNumber.Add(s.logicalLen()) - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index d751c7d8e..6317748cf 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -129,6 +129,9 @@ type sender struct { // It is initialized on demand. maxPayloadSize int + // gso is set if generic segmentation offload is enabled. + gso bool + // sndWndScale is the number of bits to shift left when reading the send // window size from a segment. sndWndScale uint8 @@ -194,6 +197,11 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. last: iss, }, + gso: ep.gso != nil, + } + + if s.gso { + s.ep.gso.MSS = uint16(maxPayloadSize) } s.cc = s.initCongestionControl(ep.cc) @@ -244,6 +252,9 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { } s.maxPayloadSize = m + if s.gso { + s.ep.gso.MSS = uint16(m) + } s.outstanding -= count if s.outstanding < 0 { @@ -338,6 +349,15 @@ func (s *sender) resendSegment() { // Resend the segment. if seg := s.writeList.Front(); seg != nil { + if seg.data.Size() > s.maxPayloadSize { + available := s.maxPayloadSize + // Split this segment up. + nSeg := seg.clone() + nSeg.data.TrimFront(available) + nSeg.sequenceNumber.UpdateForward(seqnum.Size(available)) + s.writeList.InsertAfter(seg, nSeg) + seg.data.CapLength(available) + } s.sendSegment(seg.data, seg.flags, seg.sequenceNumber) s.ep.stack.Stats().TCP.FastRetransmit.Increment() s.ep.stack.Stats().TCP.Retransmits.Increment() @@ -408,11 +428,24 @@ func (s *sender) retransmitTimerExpired() bool { return true } +// pCount returns the number of packets in the segment. Due to GSO, a segment +// can be composed of multiple packets. +func (s *sender) pCount(seg *segment) int { + size := seg.data.Size() + if size == 0 { + return 1 + } + + return (size-1)/s.maxPayloadSize + 1 +} + // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. func (s *sender) sendData() { limit := s.maxPayloadSize - + if s.gso { + limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) + } // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding @@ -427,6 +460,10 @@ func (s *sender) sendData() { end := s.sndUna.Add(s.sndWnd) var dataSent bool for ; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + if cwndLimit < limit { + limit = cwndLimit + } // We abuse the flags field to determine if we have already // assigned a sequence number to this segment. if seg.flags == 0 { @@ -518,7 +555,7 @@ func (s *sender) sendData() { seg.data.CapLength(available) } - s.outstanding++ + s.outstanding += s.pCount(seg) segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) } @@ -744,8 +781,10 @@ func (s *sender) handleRcvdSegment(seg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { + prevCount := s.pCount(seg) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) + s.outstanding -= prevCount - s.pCount(seg) break } @@ -753,7 +792,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { s.writeNext = seg.Next() } s.writeList.Remove(seg) - s.outstanding-- + s.outstanding -= s.pCount(seg) seg.decRef() ackLeft -= datalen } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index aa2a73829..5cef8ee97 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -699,7 +699,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * synOptions := header.ParseSynOptions(tcpSeg.Options(), false) // Build options w/ tsVal to be sent in the SYN-ACK. - synAckOptions := make([]byte, 40) + synAckOptions := make([]byte, header.TCPOptionsMaximumSize) offset := 0 if wantOptions.TS { offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) @@ -847,7 +847,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP // value of the window scaling option to be sent in the SYN. If synOptions.WS > // 0 then we send the WindowScale option. func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - opts := make([]byte, 40) + opts := make([]byte, header.TCPOptionsMaximumSize) offset := 0 offset += header.EncodeMSSOption(uint32(maxPayload), opts) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index b68ed8561..5637f46e3 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -651,7 +651,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - return r.WritePacket(hdr, data, ProtocolNumber, ttl) + return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { |