summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD29
-rw-r--r--pkg/tcpip/link/sharedmem/queue/rx.go2
-rw-r--r--pkg/tcpip/link/sharedmem/queuepair.go198
-rw-r--r--pkg/tcpip/link/sharedmem/rx.go6
-rw-r--r--pkg/tcpip/link/sharedmem/server_rx.go148
-rw-r--r--pkg/tcpip/link/sharedmem/server_tx.go179
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go224
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server.go334
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server_test.go220
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go89
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go19
11 files changed, 1341 insertions, 107 deletions
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 4215ee852..6c35aeecf 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -5,19 +5,25 @@ package(licenses = ["notice"])
go_library(
name = "sharedmem",
srcs = [
+ "queuepair.go",
"rx.go",
+ "server_rx.go",
+ "server_tx.go",
"sharedmem.go",
+ "sharedmem_server.go",
"sharedmem_unsafe.go",
"tx.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/cleanup",
"//pkg/log",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sharedmem/pipe",
"//pkg/tcpip/link/sharedmem/queue",
"//pkg/tcpip/stack",
"@org_golang_x_sys//unix:go_default_library",
@@ -26,9 +32,7 @@ go_library(
go_test(
name = "sharedmem_test",
- srcs = [
- "sharedmem_test.go",
- ],
+ srcs = ["sharedmem_test.go"],
library = ":sharedmem",
deps = [
"//pkg/sync",
@@ -41,3 +45,22 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "sharedmem_server_test",
+ size = "small",
+ srcs = ["sharedmem_server_test.go"],
+ deps = [
+ ":sharedmem",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
index 696e6c9e5..a78826ebc 100644
--- a/pkg/tcpip/link/sharedmem/queue/rx.go
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
@@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
}
r.tx.Flush()
-
return true
}
@@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
for {
outBufs := bufs
-
// Pull the next descriptor from the rx pipe.
b := r.rx.Pull()
if b == nil {
diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go
new file mode 100644
index 000000000..5fa6d91f0
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queuepair.go
@@ -0,0 +1,198 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "golang.org/x/sys/unix"
+)
+
+const (
+ // defaultQueueDataSize is the size of the shared memory data region that
+ // holds the scatter/gather buffers.
+ defaultQueueDataSize = 1 << 20 // 1MiB
+
+ // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors.
+ //
+ // Assuming each packet data is approximately 1280 bytes (IPv6 Minimum MTU)
+ // then we can hold approximately 1024*1024/1280 ~ 819 packets in the data
+ // area. Which means the pipe needs to be big enough to hold 819
+ // descriptors.
+ //
+ // Each descriptor is approximately 8 (slot descriptor in pipe) +
+ // 16 (packet descriptor) + 12 (for buffer descriptor) assuming each packet is
+ // stored in exactly 1 buffer descriptor (see queue/tx.go and pipe/tx.go.)
+ //
+ // Which means we need approximately 36*819 ~ 29 KiB to store all packet
+ // descriptors. We could go with a 32 KiB pipe but to give it some slack in
+ // how the upper layer may make use of the scatter gather buffers we double
+ // this to hold enough descriptors.
+ defaultQueuePipeSize = 64 << 10 // 64KiB
+
+ // defaultSharedDataSize is the size of the sharedData region used to
+ // enable/disable notifications.
+ defaultSharedDataSize = 4 << 10 // 4KiB
+)
+
+// A QueuePair represents a pair of TX/RX queues.
+type QueuePair struct {
+ // txCfg is the QueueConfig to be used for transmit queue.
+ txCfg QueueConfig
+
+ // rxCfg is the QueueConfig to be used for receive queue.
+ rxCfg QueueConfig
+}
+
+// NewQueuePair creates a shared memory QueuePair.
+func NewQueuePair() (*QueuePair, error) {
+ txCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to create tx queue: %s", err)
+ }
+
+ rxCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ closeFDs(txCfg)
+ return nil, fmt.Errorf("failed to create rx queue: %s", err)
+ }
+
+ return &QueuePair{
+ txCfg: txCfg,
+ rxCfg: rxCfg,
+ }, nil
+}
+
+// Close closes underlying tx/rx queue fds.
+func (q *QueuePair) Close() {
+ closeFDs(q.txCfg)
+ closeFDs(q.rxCfg)
+}
+
+// TXQueueConfig returns the QueueConfig for the receive queue.
+func (q *QueuePair) TXQueueConfig() QueueConfig {
+ return q.txCfg
+}
+
+// RXQueueConfig returns the QueueConfig for the transmit queue.
+func (q *QueuePair) RXQueueConfig() QueueConfig {
+ return q.rxCfg
+}
+
+type queueSizes struct {
+ dataSize int64
+ txPipeSize int64
+ rxPipeSize int64
+ sharedDataSize int64
+}
+
+func createQueueFDs(s queueSizes) (QueueConfig, error) {
+ success := false
+ var fd uintptr
+ var dataFD, txPipeFD, rxPipeFD, sharedDataFD int
+ defer func() {
+ if success {
+ return
+ }
+ closeFDs(QueueConfig{
+ EventFD: int(fd),
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ })
+ }()
+ eventFD, _, errno := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
+ if errno != 0 {
+ return QueueConfig{}, fmt.Errorf("eventfd failed: %v", error(errno))
+ }
+ dataFD, err := createFile(s.dataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err)
+ }
+ txPipeFD, err = createFile(s.txPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err)
+ }
+ rxPipeFD, err = createFile(s.rxPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err)
+ }
+ sharedDataFD, err = createFile(s.sharedDataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err)
+ }
+ success = true
+ return QueueConfig{
+ EventFD: int(eventFD),
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ }, nil
+}
+
+func createFile(size int64, initQueue bool) (fd int, err error) {
+ const tmpDir = "/dev/shm/"
+ f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
+ if err != nil {
+ return -1, fmt.Errorf("TempFile failed: %v", err)
+ }
+ defer f.Close()
+ unix.Unlink(f.Name())
+
+ if initQueue {
+ // Write the "slot-free" flag in the initial queue.
+ if _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0); err != nil {
+ return -1, fmt.Errorf("WriteAt failed: %v", err)
+ }
+ }
+
+ fd, err = unix.Dup(int(f.Fd()))
+ if err != nil {
+ return -1, fmt.Errorf("unix.Dup(%d) failed: %v", f.Fd(), err)
+ }
+
+ if err := unix.Ftruncate(fd, size); err != nil {
+ unix.Close(fd)
+ return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, err)
+ }
+
+ return fd, nil
+}
+
+func closeFDs(c QueueConfig) {
+ unix.Close(c.DataFD)
+ unix.Close(c.EventFD)
+ unix.Close(c.TxPipeFD)
+ unix.Close(c.RxPipeFD)
+ unix.Close(c.SharedDataFD)
+}
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index e882a128c..399317335 100644
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -108,6 +108,12 @@ func (r *rx) cleanup() {
unix.Close(r.eventFD)
}
+// notify writes to the tx.eventFD to indicate to the peer that there is data to
+// be read.
+func (r *rx) notify() {
+ unix.Write(r.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+}
+
// postAndReceive posts the provided buffers (if any), and then tries to read
// from the receive queue.
//
diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go
new file mode 100644
index 000000000..2ad8bf650
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_rx.go
@@ -0,0 +1,148 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+type serverRx struct {
+ // packetPipe represents the receive end of the pipe that carries the packet
+ // descriptors sent by the client.
+ packetPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that will carry
+ // completion notifications from the server to the client.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when transmission is completed.
+ eventFD int
+
+ // sharedData the memory region to use to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverTx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverRx) init(c *QueueConfig) error {
+ // Map in all buffers.
+ packetPipeMem, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu := cleanup.Make(func() { unix.Munmap(packetPipeMem) })
+ defer cu.Clean()
+
+ completionPipeMem, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(completionPipeMem) })
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(data) })
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(sharedData) })
+
+ // Duplicate the eventFD so that caller can close it but we can still
+ // use it.
+ efd, err := unix.Dup(c.EventFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Close(efd) })
+
+ // Set the eventfd as non-blocking.
+ if err := unix.SetNonblock(efd, true); err != nil {
+ return err
+ }
+
+ s.packetPipe.Init(packetPipeMem)
+ s.completionPipe.Init(completionPipeMem)
+ s.data = data
+ s.eventFD = efd
+ s.sharedData = sharedData
+
+ cu.Release()
+ return nil
+}
+
+func (s *serverRx) cleanup() {
+ unix.Munmap(s.packetPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ unix.Close(s.eventFD)
+}
+
+// completionNotificationSize is size in bytes of a completion notification sent
+// on the completion queue after a transmitted packet has been handled.
+const completionNotificationSize = 8
+
+// receive receives a single packet from the packetPipe.
+func (s *serverRx) receive() []byte {
+ desc := s.packetPipe.Pull()
+ if desc == nil {
+ return nil
+ }
+
+ pktInfo := queue.DecodeTxPacketHeader(desc)
+ contents := make([]byte, 0, pktInfo.Size)
+ toCopy := pktInfo.Size
+ for i := 0; i < pktInfo.BufferCount; i++ {
+ txBuf := queue.DecodeTxBufferHeader(desc, i)
+ if txBuf.Size <= toCopy {
+ contents = append(contents, s.data[txBuf.Offset:][:txBuf.Size]...)
+ toCopy -= txBuf.Size
+ continue
+ }
+ contents = append(contents, s.data[txBuf.Offset:][:toCopy]...)
+ break
+ }
+
+ // Flush to let peer know that slots queued for transmission have been handled
+ // and its free to reuse the slots.
+ s.packetPipe.Flush()
+ // Encode packet completion.
+ b := s.completionPipe.Push(completionNotificationSize)
+ queue.EncodeTxCompletion(b, pktInfo.ID)
+ s.completionPipe.Flush()
+ return contents
+}
+
+func (s *serverRx) waitForPackets() {
+ var tmp [8]byte
+ rawfile.BlockingRead(s.eventFD, tmp[:])
+}
diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go
new file mode 100644
index 000000000..9370b2a46
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_tx.go
@@ -0,0 +1,179 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// serverTx represents the server end of the sharedmem queue and is used to send
+// packets to the peer in the buffers posted by the peer in the fillPipe.
+type serverTx struct {
+ // fillPipe represents the receive end of the pipe that carries the RxBuffers
+ // posted by the peer.
+ fillPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that carries the
+ // descriptors for filled RxBuffers.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when fill requests are fulfilled.
+ eventFD int
+
+ // sharedData the memory region to use to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all tstate needed by the serverTx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverTx) init(c *QueueConfig) error {
+ // Map in all buffers.
+ fillPipeMem, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu := cleanup.Make(func() { unix.Munmap(fillPipeMem) })
+ defer cu.Clean()
+
+ completionPipeMem, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(completionPipeMem) })
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(data) })
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(sharedData) })
+
+ // Duplicate the eventFD so that caller can close it but we can still
+ // use it.
+ efd, err := unix.Dup(c.EventFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Close(efd) })
+
+ // Set the eventfd as non-blocking.
+ if err := unix.SetNonblock(efd, true); err != nil {
+ return err
+ }
+
+ cu.Release()
+
+ s.fillPipe.Init(fillPipeMem)
+ s.completionPipe.Init(completionPipeMem)
+ s.data = data
+ s.eventFD = efd
+ s.sharedData = sharedData
+
+ return nil
+}
+
+func (s *serverTx) cleanup() {
+ unix.Munmap(s.fillPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ unix.Close(s.eventFD)
+}
+
+// fillPacket copies the data in the provided views into buffers pulled from the
+// fillPipe and returns a slice of RxBuffers that contain the copied data as
+// well as the total number of bytes copied.
+//
+// To avoid allocations the filledBuffers are appended to the buffers slice
+// which will be grown as required.
+func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) {
+ filledBuffers = buffers[:0]
+ // fillBuffer copies as much of the views as possible into the provided buffer
+ // and returns any left over views (if any).
+ fillBuffer := func(buffer *queue.RxBuffer, views []buffer.View) (left []buffer.View) {
+ if len(views) == 0 {
+ return nil
+ }
+ availBytes := buffer.Size
+ copied := uint64(0)
+ for availBytes > 0 && len(views) > 0 {
+ n := copy(s.data[buffer.Offset+copied:][:uint64(buffer.Size)-copied], views[0])
+ views[0].TrimFront(n)
+ if !views[0].IsEmpty() {
+ break
+ }
+ views = views[1:]
+ copied += uint64(n)
+ availBytes -= uint32(n)
+ }
+ buffer.Size = uint32(copied)
+ return views
+ }
+
+ for len(views) > 0 {
+ var b []byte
+ // Spin till we get a free buffer reposted by the peer.
+ for {
+ if b = s.fillPipe.Pull(); b != nil {
+ break
+ }
+ }
+ rxBuffer := queue.DecodeRxBufferHeader(b)
+ // Copy the packet into the posted buffer.
+ views = fillBuffer(&rxBuffer, views)
+ totalCopied += rxBuffer.Size
+ filledBuffers = append(filledBuffers, rxBuffer)
+ }
+
+ return filledBuffers, totalCopied
+}
+
+func (s *serverTx) transmit(views []buffer.View) bool {
+ buffers := make([]queue.RxBuffer, 8)
+ buffers, totalCopied := s.fillPacket(views, buffers)
+ b := s.completionPipe.Push(queue.RxCompletionSize(len(buffers)))
+ if b == nil {
+ return false
+ }
+ queue.EncodeRxCompletion(b, totalCopied, 0 /* reserved */)
+ for i := 0; i < len(buffers); i++ {
+ queue.EncodeRxCompletionBuffer(b, i, buffers[i])
+ }
+ s.completionPipe.Flush()
+ s.fillPipe.Flush()
+ return true
+}
+
+func (s *serverTx) notify() {
+ unix.Write(s.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 66efe6472..e2a8c4863 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -24,6 +24,7 @@
package sharedmem
import (
+ "fmt"
"sync/atomic"
"golang.org/x/sys/unix"
@@ -32,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -63,16 +65,97 @@ type QueueConfig struct {
SharedDataFD int
}
+// FDs returns the FD's in the QueueConfig as a slice of ints. This must
+// be used in conjunction with QueueConfigFromFDs to ensure the order
+// of FDs matches when reconstructing the config when serialized or sent
+// as part of control messages.
+func (q *QueueConfig) FDs() []int {
+ return []int{q.DataFD, q.EventFD, q.TxPipeFD, q.RxPipeFD, q.SharedDataFD}
+}
+
+// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each
+// entry represents an file descriptor. The order of FDs in the slice must be in
+// the order specified below for the config to be valid. QueueConfig.FDs()
+// should be used when the config needs to be serialized or sent as part of a
+// control message to ensure the correct order.
+func QueueConfigFromFDs(fds []int) (QueueConfig, error) {
+ if len(fds) != 5 {
+ return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds))
+ }
+ return QueueConfig{
+ DataFD: fds[0],
+ EventFD: fds[1],
+ TxPipeFD: fds[2],
+ RxPipeFD: fds[3],
+ SharedDataFD: fds[4],
+ }, nil
+}
+
+// Options specify the details about the sharedmem endpoint to be created.
+type Options struct {
+ // MTU is the mtu to use for this endpoint.
+ MTU uint32
+
+ // BufferSize is the size of each scatter/gather buffer that will hold packet
+ // data.
+ //
+ // NOTE: This directly determines number of packets that can be held in
+ // the ring buffer at any time. This does not have to be sized to the MTU as
+ // the shared memory queue design allows usage of more than one buffer to be
+ // used to make up a given packet.
+ BufferSize uint32
+
+ // LinkAddress is the link address for this endpoint (required).
+ LinkAddress tcpip.LinkAddress
+
+ // TX is the transmit queue configuration for this shared memory endpoint.
+ TX QueueConfig
+
+ // RX is the receive queue configuration for this shared memory endpoint.
+ RX QueueConfig
+
+ // PeerFD is the fd for the connected peer which can be used to detect
+ // peer disconnects.
+ PeerFD int
+
+ // OnClosed is a function that is called when the endpoint is being closed
+ // (probably due to peer going away)
+ OnClosed func(err tcpip.Error)
+
+ // TXChecksumOffload if true, indicates that this endpoints capability
+ // set should include CapabilityTXChecksumOffload.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload if true, indicates that this endpoints capability
+ // set should include CapabilityRXChecksumOffload.
+ RXChecksumOffload bool
+}
+
type endpoint struct {
// mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
mtu uint32
// bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
bufferSize uint32
// addr is the local address of this endpoint.
+ // addr is immutable.
addr tcpip.LinkAddress
+ // peerFD is an fd to the peer that can be used to detect when the
+ // peer is gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
// rx is the receive queue.
rx rx
@@ -83,34 +166,55 @@ type endpoint struct {
// Wait group used to indicate that all workers have stopped.
completed sync.WaitGroup
+ // onClosed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ onClosed func(tcpip.Error)
+
// mu protects the following fields.
mu sync.Mutex
// tx is the transmit queue.
+ // +checklocks:mu
tx tx
// workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
workerStarted bool
}
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
+func New(opts Options) (stack.LinkEndpoint, error) {
e := &endpoint{
- mtu: mtu,
- bufferSize: bufferSize,
- addr: addr,
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
}
- if err := e.tx.init(bufferSize, &tx); err != nil {
+ if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil {
return nil, err
}
- if err := e.rx.init(bufferSize, &rx); err != nil {
+ if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil {
e.tx.cleanup()
return nil, err
}
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
return e, nil
}
@@ -124,8 +228,8 @@ func (e *endpoint) Close() {
// Cleanup the queues inline if the worker hasn't started yet; we also
// know it won't start from now on because stopRequested is set to 1.
e.mu.Lock()
+ defer e.mu.Unlock()
workerPresent := e.workerStarted
- e.mu.Unlock()
if !workerPresent {
e.tx.cleanup()
@@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
e.workerStarted = true
e.completed.Add(1)
+
+ // Spin up a goroutine to monitor for peer shutdown.
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ go func() {
+ defer e.completed.Done()
+ b := make([]byte, 1)
+ // When sharedmem endpoint is in use the peerFD is never used for any data
+ // transfer and this Read should only return if the peer is shutting down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ }()
+ }
+
// Link endpoints are not savable. When transportation endpoints
// are saved, they stop sending outgoing packets and all
// incoming packets are rejected.
@@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool {
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *endpoint) MTU() uint32 {
- return e.mtu - header.EthernetMinimumSize
+ return e.mtu - e.hdrSize
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
}
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
// ethernet frame header size.
-func (*endpoint) MaxHeaderLength() uint16 {
- return header.EthernetMinimumSize
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
}
// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
@@ -205,17 +325,15 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WriteRawPacket implements stack.LinkEndpoint.
func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+// +checklocks:e.mu
+func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.addr != "" {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
views := pkt.Views()
// Transmit the packet.
- e.mu.Lock()
ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
if !ok {
return &tcpip.ErrWouldBlock{}
}
@@ -223,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
return nil
}
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
+func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
}
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
@@ -268,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
Data: buffer.View(b).ToVectorisedView(),
})
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- continue
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
}
- eth := header.Ethernet(hdr)
// Send packet up the stack.
- d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt)
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
}
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
// Clean state.
e.tx.cleanup()
e.rx.cleanup()
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go
new file mode 100644
index 000000000..16feb64b2
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go
@@ -0,0 +1,334 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "sync/atomic"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type serverEndpoint struct {
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
+ mtu uint32
+
+ // bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
+ bufferSize uint32
+
+ // addr is the local address of this endpoint.
+ // addr is immutable
+ addr tcpip.LinkAddress
+
+ // rx is the receive queue.
+ rx serverRx
+
+ // stopRequested is to be accessed atomically only, and determines if the
+ // worker goroutines should stop.
+ stopRequested uint32
+
+ // Wait group used to indicate that all workers have stopped.
+ completed sync.WaitGroup
+
+ // peerFD is an fd to the peer that can be used to detect when the peer is
+ // gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
+ // onClosed is a function to be called when the FD's peer (if any) closes its
+ // end of the communication pipe.
+ onClosed func(tcpip.Error)
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // tx is the transmit queue.
+ // +checklocks:mu
+ tx serverTx
+
+ // workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
+ workerStarted bool
+}
+
+// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be
+// broken up into buffers of "bufferSize" bytes.
+func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) {
+ e := &serverEndpoint{
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
+ }
+
+ if err := e.tx.init(&opts.RX); err != nil {
+ return nil, err
+ }
+
+ if err := e.rx.init(&opts.TX); err != nil {
+ e.tx.cleanup()
+ return nil, err
+ }
+
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
+
+ return e, nil
+}
+
+// Close frees all resources associated with the endpoint.
+func (e *serverEndpoint) Close() {
+ // Tell dispatch goroutine to stop, then write to the eventfd so that it wakes
+ // up in case it's sleeping.
+ atomic.StoreUint32(&e.stopRequested, 1)
+ unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+ // Cleanup the queues inline if the worker hasn't started yet; we also know it
+ // won't start from now on because stopRequested is set to 1.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ workerPresent := e.workerStarted
+
+ if !workerPresent {
+ e.tx.cleanup()
+ e.rx.cleanup()
+ }
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
+func (e *serverEndpoint) Wait() {
+ e.completed.Wait()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
+// reads packets from the rx queue.
+func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
+ e.workerStarted = true
+ e.completed.Add(1)
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ // Spin up a goroutine to monitor for peer shutdown.
+ go func() {
+ b := make([]byte, 1)
+ // When sharedmem endpoint is in use the peerFD is never used for any
+ // data transfer and this Read should only return if the peer is
+ // shutting down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ e.completed.Done()
+ }()
+ }
+ // Link endpoints are not savable. When transportation endpoints are saved,
+ // they stop sending outgoing packets and all incoming packets are rejected.
+ go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
+ }
+ e.mu.Unlock()
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *serverEndpoint) IsAttached() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.workerStarted
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *serverEndpoint) MTU() uint32 {
+ return e.mtu - e.hdrSize
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
+// ethernet frame header size.
+func (e *serverEndpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
+// link address.
+func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ ethHdr := &header.EthernetFields{
+ DstAddr: remote,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*serverEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+// +checklocks:e.mu
+func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+
+ views := pkt.Views()
+ ok := e.tx.transmit(views)
+ if !ok {
+ return &tcpip.ErrWouldBlock{}
+ }
+
+ return nil
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ // Transmit the packet.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *serverEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
+}
+
+// dispatchLoop reads packets from the rx queue in a loop and dispatches them
+// to the network stack.
+func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
+ for atomic.LoadUint32(&e.stopRequested) == 0 {
+ b := e.rx.receive()
+ if b == nil {
+ e.rx.waitForPackets()
+ continue
+ }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(b).ToVectorisedView(),
+ })
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
+ }
+ // Send packet up the stack.
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Clean state.
+ e.tx.cleanup()
+ e.rx.cleanup()
+
+ e.completed.Done()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
new file mode 100644
index 000000000..1bc58614e
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem_server_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+const (
+ localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
+ remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
+ localIPv4Address = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02")
+ serverPort = 10001
+
+ defaultMTU = 1500
+ defaultBufferSize = 1500
+)
+
+type stackOptions struct {
+ ep stack.LinkEndpoint
+ addr tcpip.Address
+}
+
+func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) {
+ st := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ })
+ nicID := tcpip.NICID(1)
+ sniffEP := sniffer.New(stackOpts.ep)
+ opts := stack.NICOptions{Name: "eth0"}
+ if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil {
+ return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err)
+ }
+
+ // Add Protocol Address.
+ protocolNum := ipv4.ProtocolNumber
+ routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}
+ if len(stackOpts.addr) == 16 {
+ routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}}
+ protocolNum = ipv6.ProtocolNumber
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: protocolNum,
+ AddressWithPrefix: stackOpts.addr.WithPrefix(),
+ }
+ if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ // Setup route table.
+ st.SetRouteTable(routeTable)
+
+ return st, nil
+}
+
+func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+ ep, err := sharedmem.New(sharedmem.Options{
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: localLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+ ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: remoteLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+type testContext struct {
+ clientStk *stack.Stack
+ serverStk *stack.Stack
+ peerFDs [2]int
+}
+
+func newTestContext(t *testing.T) *testContext {
+ peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0)
+ if err != nil {
+ t.Fatalf("failed to create peerFDs: %s", err)
+ }
+ q, err := sharedmem.NewQueuePair()
+ if err != nil {
+ t.Fatalf("failed to create sharedmem queue: %s", err)
+ }
+ clientStack, err := newClientStack(t, q, peerFDs[0])
+ if err != nil {
+ q.Close()
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ t.Fatalf("failed to create client stack: %s", err)
+ }
+ serverStack, err := newServerStack(t, q, peerFDs[1])
+ if err != nil {
+ q.Close()
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ clientStack.Close()
+ t.Fatalf("failed to create server stack: %s", err)
+ }
+ return &testContext{
+ clientStk: clientStack,
+ serverStk: serverStack,
+ peerFDs: peerFDs,
+ }
+}
+
+func (ctx *testContext) cleanup() {
+ unix.Close(ctx.peerFDs[0])
+ unix.Close(ctx.peerFDs[1])
+ ctx.clientStk.Close()
+ ctx.serverStk.Close()
+}
+
+func TestServerRoundTrip(t *testing.T) {
+ ctx := newTestContext(t)
+ defer ctx.cleanup()
+ listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
+ l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("failed to start TCP Listener: %s", err)
+ }
+ defer l.Close()
+ var responseString = "response"
+ go func() {
+ http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(responseString))
+ }))
+ }()
+
+ dialFunc := func(address, protocol string) (net.Conn, error) {
+ return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber)
+ }
+
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ Dial: dialFunc,
+ },
+ }
+ serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort)
+ response, err := httpClient.Get(serverURL)
+ if err != nil {
+ t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
+ }
+ if got, want := response.StatusCode, http.StatusOK; got != want {
+ t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+ }
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+ }
+ response.Body.Close()
+ if got, want := string(body), responseString; got != want {
+ t.Fatalf("unexpected response got: %s, want: %s", got, want)
+ }
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index d6d953085..bb094da63 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -19,9 +19,7 @@ package sharedmem
import (
"bytes"
- "io/ioutil"
"math/rand"
- "os"
"strings"
"testing"
"time"
@@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
t: t,
packetCh: make(chan struct{}, 1000000),
}
- c.txCfg = createQueueFDs(t, queueSizes{
+ c.txCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
-
- c.rxCfg = createQueueFDs(t, queueSizes{
+ if err != nil {
+ t.Fatalf("createQueueFDs for tx failed: %s", err)
+ }
+ c.rxCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
+ if err != nil {
+ t.Fatalf("createQueueFDs for rx failed: %s", err)
+ }
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(Options{
+ MTU: mtu,
+ BufferSize: bufferSize,
+ LinkAddress: addr,
+ TX: c.txCfg,
+ RX: c.rxCfg,
+ PeerFD: -1,
+ })
if err != nil {
t.Fatalf("New failed: %v", err)
}
@@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.
func (c *testContext) cleanup() {
c.ep.Close()
- closeFDs(&c.txCfg)
- closeFDs(&c.rxCfg)
+ closeFDs(c.txCfg)
+ closeFDs(c.rxCfg)
c.txq.cleanup()
c.rxq.cleanup()
}
@@ -191,69 +201,6 @@ func shuffle(b []int) {
}
}
-func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir, ok := os.LookupEnv("TEST_TMPDIR")
- if !ok {
- tmpDir = os.Getenv("TMPDIR")
- }
- f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- defer f.Close()
- unix.Unlink(f.Name())
-
- if initQueue {
- // Write the "slot-free" flag in the initial queue.
- _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
- if err != nil {
- t.Fatalf("WriteAt failed: %v", err)
- }
- }
-
- fd, err := unix.Dup(int(f.Fd()))
- if err != nil {
- t.Fatalf("Dup failed: %v", err)
- }
-
- if err := unix.Ftruncate(fd, size); err != nil {
- unix.Close(fd)
- t.Fatalf("Ftruncate failed: %v", err)
- }
-
- return fd
-}
-
-func closeFDs(c *QueueConfig) {
- unix.Close(c.DataFD)
- unix.Close(c.EventFD)
- unix.Close(c.TxPipeFD)
- unix.Close(c.RxPipeFD)
- unix.Close(c.SharedDataFD)
-}
-
-type queueSizes struct {
- dataSize int64
- txPipeSize int64
- rxPipeSize int64
- sharedDataSize int64
-}
-
-func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
- fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if err != 0 {
- t.Fatalf("eventfd failed: %v", error(err))
- }
-
- return QueueConfig{
- EventFD: int(fd),
- DataFD: createFile(t, s.dataSize, false),
- TxPipeFD: createFile(t, s.txPipeSize, true),
- RxPipeFD: createFile(t, s.rxPipeSize, true),
- SharedDataFD: createFile(t, s.sharedDataSize, false),
- }
-}
-
// TestSimpleSend sends 1000 packets with random header and payload sizes,
// then checks that the right payload is received on the shared memory queues.
func TestSimpleSend(t *testing.T) {
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index e3210051f..5ffcb8ab4 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -28,10 +28,12 @@ const (
// tx holds all state associated with a tx queue.
type tx struct {
- data []byte
- q queue.Tx
- ids idManager
- bufs bufferManager
+ data []byte
+ q queue.Tx
+ ids idManager
+ bufs bufferManager
+ eventFD int
+ sharedDataFD int
}
// init initializes all state needed by the tx queue based on the information
@@ -64,7 +66,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error {
t.ids.init()
t.bufs.init(0, len(data), int(mtu))
t.data = data
-
+ t.eventFD = c.EventFD
+ t.sharedDataFD = c.SharedDataFD
return nil
}
@@ -142,6 +145,12 @@ func (t *tx) transmit(bufs ...buffer.View) bool {
return true
}
+// notify writes to the tx.eventFD to indicate to the peer that there is data to
+// be read.
+func (t *tx) notify() {
+ unix.Write(t.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+}
+
// getBuffer returns a memory region mapped to the full contents of the given
// file descriptor.
func getBuffer(fd int) ([]byte, error) {