summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD1
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go8
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server.go27
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go11
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_unsafe.go33
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go16
6 files changed, 66 insertions, 30 deletions
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index f8076d83c..af755473c 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/cleanup",
"//pkg/eventfd",
"//pkg/log",
+ "//pkg/memutil",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index bcb37a465..b75522a51 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -343,10 +343,10 @@ func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkPr
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
return err
}
e.tx.notify()
@@ -354,13 +354,13 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
var err tcpip.Error
e.mu.Lock()
defer e.mu.Unlock()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
break
}
n++
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go
index ccc84989d..43c5b8c63 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_server.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go
@@ -218,14 +218,24 @@ func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcp
eth.Encode(ethHdr)
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (*serverEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
- return &tcpip.ErrNotSupported{}
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket
+func (e *serverEndpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ views := pkt.Views()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ ok := e.tx.transmit(views)
+ if !ok {
+ return &tcpip.ErrWouldBlock{}
+ }
+ e.tx.notify()
+ return nil
}
// +checklocks:e.mu
func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ if e.addr != "" {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
views := pkt.Views()
ok := e.tx.transmit(views)
@@ -238,11 +248,12 @@ func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *serverEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
// Transmit the packet.
e.mu.Lock()
defer e.mu.Unlock()
- if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
return err
}
e.tx.notify()
@@ -250,13 +261,13 @@ func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkPr
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *serverEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *serverEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
var err tcpip.Error
e.mu.Lock()
defer e.mu.Unlock()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
break
}
n++
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 66ffc33b8..a49f5f87d 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -210,6 +210,7 @@ func TestSimpleSend(t *testing.T) {
// Prepare route.
var r stack.RouteInfo
r.RemoteLinkAddress = remoteLinkAddr
+ r.LocalLinkAddress = localLinkAddr
for iters := 1000; iters > 0; iters-- {
func() {
@@ -227,8 +228,11 @@ func TestSimpleSend(t *testing.T) {
Data: data.ToVectorisedView(),
})
copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf)
-
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ // Every PacketBuffer must have these set:
+ // See nic.writePacket.
+ pkt.EgressRoute = r
+ pkt.NetworkProtocolNumber = proto
if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -297,8 +301,11 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
// the minimum size of the ethernet header.
ReserveHeaderBytes: header.EthernetMinimumSize,
})
-
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ // Every PacketBuffer must have these set:
+ // See nic.writePacket.
+ pkt.EgressRoute = r
+ pkt.NetworkProtocolNumber = proto
if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
index f7e816a41..d974c266e 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
@@ -15,7 +15,12 @@
package sharedmem
import (
+ "fmt"
+ "reflect"
"unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/memutil"
)
// sharedDataPointer converts the shared data slice into a pointer so that it
@@ -23,3 +28,31 @@ import (
func sharedDataPointer(sharedData []byte) *uint32 {
return (*uint32)(unsafe.Pointer(&sharedData[0:4][0]))
}
+
+// getBuffer returns a memory region mapped to the full contents of the given
+// file descriptor.
+func getBuffer(fd int) ([]byte, error) {
+ var s unix.Stat_t
+ if err := unix.Fstat(fd, &s); err != nil {
+ return nil, err
+ }
+
+ // Check that size doesn't overflow an int.
+ if s.Size > int64(^uint(0)>>1) {
+ return nil, unix.EDOM
+ }
+
+ addr, err := memutil.MapFile(0 /* addr */, uintptr(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE, uintptr(fd), 0 /*offset*/)
+ if err != nil {
+ return nil, fmt.Errorf("failed to map memory for buffer fd: %d, error: %s", fd, err)
+ }
+
+ // Use unsafe to conver addr into a []byte.
+ var b []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+ hdr.Data = addr
+ hdr.Len = int(s.Size)
+ hdr.Cap = int(s.Size)
+
+ return b, nil
+}
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index 35e5bff12..d6c61afee 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -152,22 +152,6 @@ func (t *tx) notify() {
t.eventFD.Notify()
}
-// getBuffer returns a memory region mapped to the full contents of the given
-// file descriptor.
-func getBuffer(fd int) ([]byte, error) {
- var s unix.Stat_t
- if err := unix.Fstat(fd, &s); err != nil {
- return nil, err
- }
-
- // Check that size doesn't overflow an int.
- if s.Size > int64(^uint(0)>>1) {
- return nil, unix.EDOM
- }
-
- return unix.Mmap(fd, 0, int(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE)
-}
-
// idDescriptor is used by idManager to either point to a tx buffer (in case
// the ID is assigned) or to the next free element (if the id is not assigned).
type idDescriptor struct {