diff options
-rw-r--r-- | pkg/tcpip/link/sharedmem/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/link/sharedmem/sharedmem.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/link/sharedmem/sharedmem_server.go | 27 | ||||
-rw-r--r-- | pkg/tcpip/link/sharedmem/sharedmem_test.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/link/sharedmem/sharedmem_unsafe.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/link/sharedmem/tx.go | 16 |
6 files changed, 66 insertions, 30 deletions
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index f8076d83c..af755473c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/cleanup", "//pkg/eventfd", "//pkg/log", + "//pkg/memutil", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index bcb37a465..b75522a51 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -343,10 +343,10 @@ func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkPr // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if err := e.writePacketLocked(r, protocol, pkt); err != nil { + if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { return err } e.tx.notify() @@ -354,13 +354,13 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 var err tcpip.Error e.mu.Lock() defer e.mu.Unlock() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil { + if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { break } n++ diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go index ccc84989d..43c5b8c63 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_server.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go @@ -218,14 +218,24 @@ func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcp eth.Encode(ethHdr) } -// WriteRawPacket implements stack.LinkEndpoint. -func (*serverEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { - return &tcpip.ErrNotSupported{} +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket +func (e *serverEndpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error { + views := pkt.Views() + e.mu.Lock() + defer e.mu.Unlock() + ok := e.tx.transmit(views) + if !ok { + return &tcpip.ErrWouldBlock{} + } + e.tx.notify() + return nil } // +checklocks:e.mu func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + if e.addr != "" { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } views := pkt.Views() ok := e.tx.transmit(views) @@ -238,11 +248,12 @@ func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *serverEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // Transmit the packet. e.mu.Lock() defer e.mu.Unlock() - if err := e.writePacketLocked(r, protocol, pkt); err != nil { + if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { return err } e.tx.notify() @@ -250,13 +261,13 @@ func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkPr } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *serverEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *serverEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 var err tcpip.Error e.mu.Lock() defer e.mu.Unlock() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil { + if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { break } n++ diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 66ffc33b8..a49f5f87d 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -210,6 +210,7 @@ func TestSimpleSend(t *testing.T) { // Prepare route. var r stack.RouteInfo r.RemoteLinkAddress = remoteLinkAddr + r.LocalLinkAddress = localLinkAddr for iters := 1000; iters > 0; iters-- { func() { @@ -227,8 +228,11 @@ func TestSimpleSend(t *testing.T) { Data: data.ToVectorisedView(), }) copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + // Every PacketBuffer must have these set: + // See nic.writePacket. + pkt.EgressRoute = r + pkt.NetworkProtocolNumber = proto if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -297,8 +301,11 @@ func TestPreserveSrcAddressInSend(t *testing.T) { // the minimum size of the ethernet header. ReserveHeaderBytes: header.EthernetMinimumSize, }) - proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + // Every PacketBuffer must have these set: + // See nic.writePacket. + pkt.EgressRoute = r + pkt.NetworkProtocolNumber = proto if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go index f7e816a41..d974c266e 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go @@ -15,7 +15,12 @@ package sharedmem import ( + "fmt" + "reflect" "unsafe" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/memutil" ) // sharedDataPointer converts the shared data slice into a pointer so that it @@ -23,3 +28,31 @@ import ( func sharedDataPointer(sharedData []byte) *uint32 { return (*uint32)(unsafe.Pointer(&sharedData[0:4][0])) } + +// getBuffer returns a memory region mapped to the full contents of the given +// file descriptor. +func getBuffer(fd int) ([]byte, error) { + var s unix.Stat_t + if err := unix.Fstat(fd, &s); err != nil { + return nil, err + } + + // Check that size doesn't overflow an int. + if s.Size > int64(^uint(0)>>1) { + return nil, unix.EDOM + } + + addr, err := memutil.MapFile(0 /* addr */, uintptr(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE, uintptr(fd), 0 /*offset*/) + if err != nil { + return nil, fmt.Errorf("failed to map memory for buffer fd: %d, error: %s", fd, err) + } + + // Use unsafe to conver addr into a []byte. + var b []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + hdr.Data = addr + hdr.Len = int(s.Size) + hdr.Cap = int(s.Size) + + return b, nil +} diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index 35e5bff12..d6c61afee 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -152,22 +152,6 @@ func (t *tx) notify() { t.eventFD.Notify() } -// getBuffer returns a memory region mapped to the full contents of the given -// file descriptor. -func getBuffer(fd int) ([]byte, error) { - var s unix.Stat_t - if err := unix.Fstat(fd, &s); err != nil { - return nil, err - } - - // Check that size doesn't overflow an int. - if s.Size > int64(^uint(0)>>1) { - return nil, unix.EDOM - } - - return unix.Mmap(fd, 0, int(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE) -} - // idDescriptor is used by idManager to either point to a tx buffer (in case // the ID is assigned) or to the next free element (if the id is not assigned). type idDescriptor struct { |