diff options
Diffstat (limited to 'pkg/tcpip/link')
-rw-r--r-- | pkg/tcpip/link/fdbased/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/link/fdbased/endpoint.go | 166 | ||||
-rw-r--r-- | pkg/tcpip/link/fdbased/endpoint_test.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/link/rawfile/rawfile_unsafe.go | 31 |
4 files changed, 190 insertions, 32 deletions
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 94391433c..a4aa3feec 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -27,6 +27,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 3a79d13d4..87c8ab1fc 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -24,6 +24,7 @@ package fdbased import ( + "fmt" "syscall" "gvisor.googlesource.com/gvisor/pkg/tcpip" @@ -33,9 +34,19 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" ) +const ( + // MaxMsgsPerRecv is the maximum number of packets we want to retrieve + // in a single RecvMMsg call. + MaxMsgsPerRecv = 8 +) + // BufConfig defines the shape of the vectorised view used to read packets from the NIC. var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} +// linkDispatcher reads packets from the link FD and dispatches them to the +// NetworkDispatcher. +type linkDispatcher func() (bool, *tcpip.Error) + type endpoint struct { // fd is the file descriptor used to send and receive packets. fd int @@ -57,14 +68,25 @@ type endpoint struct { // its end of the communication pipe. closed func(*tcpip.Error) - iovecs []syscall.Iovec - views []buffer.View - dispatcher stack.NetworkDispatcher + views [][]buffer.View + iovecs [][]syscall.Iovec + msgHdrs []rawfile.MMsgHdr + inboundDispatcher linkDispatcher + dispatcher stack.NetworkDispatcher // handleLocal indicates whether packets destined to itself should be // handled by the netstack internally (true) or be forwarded to the FD // endpoint (false). handleLocal bool + + // useRecvMMsg enables use of recvmmsg() syscall instead of readv() to + // read inbound packets. This reduces # of syscalls needed to process + // packets. + // + // NOTE: recvmmsg() is only supported for sockets, so if the underlying + // FD is not a socket then the code will still fall back to the readv() + // path. + useRecvMMsg bool } // Options specify the details about the fd-based endpoint to be created. @@ -78,6 +100,7 @@ type Options struct { SaveRestore bool DisconnectOk bool HandleLocal bool + UseRecvMMsg bool } // New creates a new fd-based endpoint. @@ -85,7 +108,10 @@ type Options struct { // Makes fd non-blocking, but does not take ownership of fd, which must remain // open for the lifetime of the returned endpoint. func New(opts *Options) tcpip.LinkEndpointID { - syscall.SetNonblock(opts.FD, true) + if err := syscall.SetNonblock(opts.FD, true); err != nil { + // TODO : replace panic with an error return. + panic(fmt.Sprintf("syscall.SetNonblock(%v) failed: %v", opts.FD, err)) + } caps := stack.LinkEndpointCapabilities(0) if opts.ChecksumOffload { @@ -113,13 +139,44 @@ func New(opts *Options) tcpip.LinkEndpointID { closed: opts.ClosedFunc, addr: opts.Address, hdrSize: hdrSize, - views: make([]buffer.View, len(BufConfig)), - iovecs: make([]syscall.Iovec, len(BufConfig)), handleLocal: opts.HandleLocal, + useRecvMMsg: opts.UseRecvMMsg, + } + // For non-socket FDs we read one packet a time (e.g. TAP devices) + msgsPerRecv := 1 + e.inboundDispatcher = e.dispatch + // If the provided FD is a socket then we optimize packet reads by + // using recvmmsg() instead of read() to read packets in a batch. + if isSocketFD(opts.FD) && e.useRecvMMsg { + e.inboundDispatcher = e.recvMMsgDispatch + msgsPerRecv = MaxMsgsPerRecv + } + + e.views = make([][]buffer.View, msgsPerRecv) + for i, _ := range e.views { + e.views[i] = make([]buffer.View, len(BufConfig)) + } + e.iovecs = make([][]syscall.Iovec, msgsPerRecv) + for i, _ := range e.iovecs { + e.iovecs[i] = make([]syscall.Iovec, len(BufConfig)) + } + e.msgHdrs = make([]rawfile.MMsgHdr, msgsPerRecv) + for i, _ := range e.msgHdrs { + e.msgHdrs[i].Msg.Iov = &e.iovecs[i][0] + e.msgHdrs[i].Msg.Iovlen = uint64(len(BufConfig)) } return stack.RegisterLinkEndpoint(e) } +func isSocketFD(fd int) bool { + var stat syscall.Stat_t + if err := syscall.Fstat(fd, &stat); err != nil { + // TODO : replace panic with an error return. + panic(fmt.Sprintf("syscall.Fstat(%v,...) failed: %v", fd, err)) + } + return (stat.Mode & syscall.S_IFSOCK) == syscall.S_IFSOCK +} + // Attach launches the goroutine that reads packets from the file descriptor and // dispatches them via the provided dispatcher. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { @@ -191,12 +248,12 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b return rawfile.NonBlockingWrite2(e.fd, hdr.View(), payload.ToView()) } -func (e *endpoint) capViews(n int, buffers []int) int { +func (e *endpoint) capViews(k, n int, buffers []int) int { c := 0 for i, s := range buffers { c += s if c >= n { - e.views[i].CapLength(s - (c - n)) + e.views[k][i].CapLength(s - (c - n)) return i + 1 } } @@ -204,24 +261,26 @@ func (e *endpoint) capViews(n int, buffers []int) int { } func (e *endpoint) allocateViews(bufConfig []int) { - for i, v := range e.views { - if v != nil { - break - } - b := buffer.NewView(bufConfig[i]) - e.views[i] = b - e.iovecs[i] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), + for k := 0; k < len(e.views); k++ { + for i := 0; i < len(bufConfig); i++ { + if e.views[k][i] != nil { + break + } + b := buffer.NewView(bufConfig[i]) + e.views[k][i] = b + e.iovecs[k][i] = syscall.Iovec{ + Base: &b[0], + Len: uint64(len(b)), + } } } } // dispatch reads one packet from the file descriptor and dispatches it. -func (e *endpoint) dispatch(largeV buffer.View) (bool, *tcpip.Error) { +func (e *endpoint) dispatch() (bool, *tcpip.Error) { e.allocateViews(BufConfig) - n, err := rawfile.BlockingReadv(e.fd, e.iovecs) + n, err := rawfile.BlockingReadv(e.fd, e.iovecs[0]) if err != nil { return false, err } @@ -235,14 +294,14 @@ func (e *endpoint) dispatch(largeV buffer.View) (bool, *tcpip.Error) { remote, local tcpip.LinkAddress ) if e.hdrSize > 0 { - eth := header.Ethernet(e.views[0]) + eth := header.Ethernet(e.views[0][0]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(e.views[0]) { + switch header.IPVersion(e.views[0][0]) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: @@ -252,15 +311,71 @@ func (e *endpoint) dispatch(largeV buffer.View) (bool, *tcpip.Error) { } } - used := e.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, e.views[:used]) + used := e.capViews(0, n, BufConfig) + vv := buffer.NewVectorisedView(n, e.views[0][:used]) vv.TrimFront(e.hdrSize) e.dispatcher.DeliverNetworkPacket(e, remote, local, p, vv) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { - e.views[i] = nil + e.views[0][i] = nil + } + + return true, nil +} + +// recvMMsgDispatch reads more than one packet at a time from the file +// descriptor and dispatches it. +func (e *endpoint) recvMMsgDispatch() (bool, *tcpip.Error) { + e.allocateViews(BufConfig) + + nMsgs, err := rawfile.BlockingRecvMMsg(e.fd, e.msgHdrs) + if err != nil { + return false, err + } + // Process each of received packets. + for k := 0; k < nMsgs; k++ { + n := e.msgHdrs[k].Len + if n <= uint32(e.hdrSize) { + return false, nil + } + + var ( + p tcpip.NetworkProtocolNumber + remote, local tcpip.LinkAddress + ) + if e.hdrSize > 0 { + eth := header.Ethernet(e.views[k][0]) + p = eth.Type() + remote = eth.SourceAddress() + local = eth.DestinationAddress() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + switch header.IPVersion(e.views[k][0]) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + return true, nil + } + } + + used := e.capViews(k, int(n), BufConfig) + vv := buffer.NewVectorisedView(int(n), e.views[k][:used]) + vv.TrimFront(e.hdrSize) + e.dispatcher.DeliverNetworkPacket(e, remote, local, p, vv) + + // Prepare e.views for another packet: release used views. + for i := 0; i < used; i++ { + e.views[k][i] = nil + } + } + + for k := 0; k < nMsgs; k++ { + e.msgHdrs[k].Len = 0 } return true, nil @@ -269,9 +384,8 @@ func (e *endpoint) dispatch(largeV buffer.View) (bool, *tcpip.Error) { // dispatchLoop reads packets from the file descriptor in a loop and dispatches // them to the network stack. func (e *endpoint) dispatchLoop() *tcpip.Error { - v := buffer.NewView(header.MaxIPPacketSize) for { - cont, err := e.dispatch(v) + cont, err := e.inboundDispatcher() if err != nil || !cont { if e.closed != nil { e.closed(err) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 226639443..14abacdf2 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -28,6 +28,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" ) @@ -309,9 +310,22 @@ func TestBufConfigFirst(t *testing.T) { func build(bufConfig []int) *endpoint { e := &endpoint{ - views: make([]buffer.View, len(bufConfig)), - iovecs: make([]syscall.Iovec, len(bufConfig)), + views: make([][]buffer.View, MaxMsgsPerRecv), + iovecs: make([][]syscall.Iovec, MaxMsgsPerRecv), + msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } + + for i, _ := range e.views { + e.views[i] = make([]buffer.View, len(bufConfig)) + } + for i := range e.iovecs { + e.iovecs[i] = make([]syscall.Iovec, len(bufConfig)) + } + for k, msgHdr := range e.msgHdrs { + msgHdr.Msg.Iov = &e.iovecs[k][0] + msgHdr.Msg.Iovlen = uint64(len(bufConfig)) + } + e.allocateViews(bufConfig) return e } @@ -356,12 +370,12 @@ var capLengthTestCases = []struct { func TestCapLength(t *testing.T) { for _, c := range capLengthTestCases { e := build(c.config) - used := e.capViews(c.n, c.config) + used := e.capViews(0, c.n, c.config) if used != c.wantUsed { t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) } - lengths := make([]int, len(e.views)) - for i, v := range e.views { + lengths := make([]int, len(e.views[0])) + for i, v := range e.views[0] { lengths[i] = len(v) } if !reflect.DeepEqual(lengths, c.wantLengths) { diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index be4a4fa9c..5deea093a 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -124,7 +124,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { // BlockingReadv reads from a file descriptor that is set up as non-blocking and // stores the data in a list of iovecs buffers. If no data is available, it will -// block in a poll() syscall until the file descirptor becomes readable. +// block in a poll() syscall until the file descriptor becomes readable. func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) @@ -143,3 +143,32 @@ func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { } } } + +// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux. +type MMsgHdr struct { + Msg syscall.Msghdr + Len uint32 + _ [4]byte +} + +// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking +// and stores the received messages in a slice of MMsgHdr structures. If no data +// is available, it will block in a poll() syscall until the file descriptor +// becomes readable. +func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { + for { + n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) + if e == 0 { + return int(n), nil + } + + event := pollEvent{ + fd: int32(fd), + events: 1, // POLLIN + } + + if _, e := blockingPoll(&event, 1, -1); e != 0 && e != syscall.EINTR { + return 0, TranslateErrno(e) + } + } +} |