From 84063e88c382748d6990eb24834c0553d530d517 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 5 Oct 2021 15:39:50 -0700 Subject: Add server implementation for sharedmem endpoints. PiperOrigin-RevId: 401088040 --- pkg/tcpip/link/sharedmem/sharedmem.go | 224 ++++++++++++++++++++++++++++++---- 1 file changed, 198 insertions(+), 26 deletions(-) (limited to 'pkg/tcpip/link/sharedmem/sharedmem.go') diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 66efe6472..e2a8c4863 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -24,6 +24,7 @@ package sharedmem import ( + "fmt" "sync/atomic" "golang.org/x/sys/unix" @@ -32,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -63,16 +65,97 @@ type QueueConfig struct { SharedDataFD int } +// FDs returns the FD's in the QueueConfig as a slice of ints. This must +// be used in conjunction with QueueConfigFromFDs to ensure the order +// of FDs matches when reconstructing the config when serialized or sent +// as part of control messages. +func (q *QueueConfig) FDs() []int { + return []int{q.DataFD, q.EventFD, q.TxPipeFD, q.RxPipeFD, q.SharedDataFD} +} + +// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each +// entry represents an file descriptor. The order of FDs in the slice must be in +// the order specified below for the config to be valid. QueueConfig.FDs() +// should be used when the config needs to be serialized or sent as part of a +// control message to ensure the correct order. +func QueueConfigFromFDs(fds []int) (QueueConfig, error) { + if len(fds) != 5 { + return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds)) + } + return QueueConfig{ + DataFD: fds[0], + EventFD: fds[1], + TxPipeFD: fds[2], + RxPipeFD: fds[3], + SharedDataFD: fds[4], + }, nil +} + +// Options specify the details about the sharedmem endpoint to be created. +type Options struct { + // MTU is the mtu to use for this endpoint. + MTU uint32 + + // BufferSize is the size of each scatter/gather buffer that will hold packet + // data. + // + // NOTE: This directly determines number of packets that can be held in + // the ring buffer at any time. This does not have to be sized to the MTU as + // the shared memory queue design allows usage of more than one buffer to be + // used to make up a given packet. + BufferSize uint32 + + // LinkAddress is the link address for this endpoint (required). + LinkAddress tcpip.LinkAddress + + // TX is the transmit queue configuration for this shared memory endpoint. + TX QueueConfig + + // RX is the receive queue configuration for this shared memory endpoint. + RX QueueConfig + + // PeerFD is the fd for the connected peer which can be used to detect + // peer disconnects. + PeerFD int + + // OnClosed is a function that is called when the endpoint is being closed + // (probably due to peer going away) + OnClosed func(err tcpip.Error) + + // TXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityTXChecksumOffload. + TXChecksumOffload bool + + // RXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityRXChecksumOffload. + RXChecksumOffload bool +} + type endpoint struct { // mtu (maximum transmission unit) is the maximum size of a packet. + // mtu is immutable. mtu uint32 // bufferSize is the size of each individual buffer. + // bufferSize is immutable. bufferSize uint32 // addr is the local address of this endpoint. + // addr is immutable. addr tcpip.LinkAddress + // peerFD is an fd to the peer that can be used to detect when the + // peer is gone. + // peerFD is immutable. + peerFD int + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // hdrSize is the size of the link layer header if any. + // hdrSize is immutable. + hdrSize uint32 + // rx is the receive queue. rx rx @@ -83,34 +166,55 @@ type endpoint struct { // Wait group used to indicate that all workers have stopped. completed sync.WaitGroup + // onClosed is a function to be called when the FD's peer (if any) closes + // its end of the communication pipe. + onClosed func(tcpip.Error) + // mu protects the following fields. mu sync.Mutex // tx is the transmit queue. + // +checklocks:mu tx tx // workerStarted specifies whether the worker goroutine was started. + // +checklocks:mu workerStarted bool } // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { +func New(opts Options) (stack.LinkEndpoint, error) { e := &endpoint{ - mtu: mtu, - bufferSize: bufferSize, - addr: addr, + mtu: opts.MTU, + bufferSize: opts.BufferSize, + addr: opts.LinkAddress, + peerFD: opts.PeerFD, + onClosed: opts.OnClosed, } - if err := e.tx.init(bufferSize, &tx); err != nil { + if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil { return nil, err } - if err := e.rx.init(bufferSize, &rx); err != nil { + if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil { e.tx.cleanup() return nil, err } + e.caps = stack.LinkEndpointCapabilities(0) + if opts.RXChecksumOffload { + e.caps |= stack.CapabilityRXChecksumOffload + } + + if opts.TXChecksumOffload { + e.caps |= stack.CapabilityTXChecksumOffload + } + + if opts.LinkAddress != "" { + e.hdrSize = header.EthernetMinimumSize + e.caps |= stack.CapabilityResolutionRequired + } return e, nil } @@ -124,8 +228,8 @@ func (e *endpoint) Close() { // Cleanup the queues inline if the worker hasn't started yet; we also // know it won't start from now on because stopRequested is set to 1. e.mu.Lock() + defer e.mu.Unlock() workerPresent := e.workerStarted - e.mu.Unlock() if !workerPresent { e.tx.cleanup() @@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { e.workerStarted = true e.completed.Add(1) + + // Spin up a goroutine to monitor for peer shutdown. + if e.peerFD >= 0 { + e.completed.Add(1) + go func() { + defer e.completed.Done() + b := make([]byte, 1) + // When sharedmem endpoint is in use the peerFD is never used for any data + // transfer and this Read should only return if the peer is shutting down. + _, err := rawfile.BlockingRead(e.peerFD, b) + if e.onClosed != nil { + e.onClosed(err) + } + }() + } + // Link endpoints are not savable. When transportation endpoints // are saved, they stop sending outgoing packets and all // incoming packets are rejected. @@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool { // MTU implements stack.LinkEndpoint.MTU. It returns the value initialized // during construction. func (e *endpoint) MTU() uint32 { - return e.mtu - header.EthernetMinimumSize + return e.mtu - e.hdrSize } // Capabilities implements stack.LinkEndpoint.Capabilities. -func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps } // MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the // ethernet frame header size. -func (*endpoint) MaxHeaderLength() uint16 { - return header.EthernetMinimumSize +func (e *endpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) } // LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local @@ -205,17 +325,15 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WriteRawPacket implements stack.LinkEndpoint. func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) +// +checklocks:e.mu +func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if e.addr != "" { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } views := pkt.Views() // Transmit the packet. - e.mu.Lock() ok := e.tx.transmit(views...) - e.mu.Unlock() - if !ok { return &tcpip.ErrWouldBlock{} } @@ -223,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol return nil } +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + if err := e.writePacketLocked(r, protocol, pkt); err != nil { + return err + } + e.tx.notify() + return nil +} + // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") +func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + n := 0 + var err tcpip.Error + e.mu.Lock() + defer e.mu.Unlock() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil { + break + } + n++ + } + // WritePackets never returns an error if it successfully transmitted at least + // one packet. + if err != nil && n == 0 { + return 0, err + } + e.tx.notify() + return n, nil } // dispatchLoop reads packets from the rx queue in a loop and dispatches them @@ -268,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { Data: buffer.View(b).ToVectorisedView(), }) - hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) - if !ok { - continue + var src, dst tcpip.LinkAddress + var proto tcpip.NetworkProtocolNumber + if e.addr != "" { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + continue + } + eth := header.Ethernet(hdr) + src = eth.SourceAddress() + dst = eth.DestinationAddress() + proto = eth.Type() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data().PullUp(1) + if !ok { + continue + } + switch header.IPVersion(h) { + case header.IPv4Version: + proto = header.IPv4ProtocolNumber + case header.IPv6Version: + proto = header.IPv6ProtocolNumber + default: + continue + } } - eth := header.Ethernet(hdr) // Send packet up the stack. - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) + d.DeliverNetworkPacket(src, dst, proto, pkt) } + e.mu.Lock() + defer e.mu.Unlock() + // Clean state. e.tx.cleanup() e.rx.cleanup() -- cgit v1.2.3