summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/fd/BUILD3
-rw-r--r--pkg/fd/fd.go8
-rw-r--r--pkg/metric/BUILD7
-rw-r--r--pkg/p9/client.go171
-rw-r--r--pkg/p9/p9test/BUILD2
-rw-r--r--pkg/p9/p9test/client_test.go95
-rw-r--r--pkg/p9/p9test/p9test.go2
-rw-r--r--pkg/p9/transport_flipcall.go13
-rw-r--r--pkg/seccomp/seccomp_unsafe.go2
-rw-r--r--pkg/sentry/arch/BUILD7
-rw-r--r--pkg/sentry/fs/host/tty.go15
-rw-r--r--pkg/sentry/fs/tty/BUILD1
-rw-r--r--pkg/sentry/fs/tty/dir.go3
-rw-r--r--pkg/sentry/fs/tty/master.go17
-rw-r--r--pkg/sentry/fs/tty/slave.go13
-rw-r--r--pkg/sentry/fs/tty/terminal.go92
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go8
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go2
-rw-r--r--pkg/sentry/fsimpl/memfs/directory.go24
-rw-r--r--pkg/sentry/hostcpu/BUILD1
-rw-r--r--pkg/sentry/hostcpu/getcpu_arm64.s28
-rw-r--r--pkg/sentry/kernel/BUILD8
-rw-r--r--pkg/sentry/kernel/memevent/BUILD7
-rw-r--r--pkg/sentry/kernel/sessions.go20
-rw-r--r--pkg/sentry/kernel/task_start.go3
-rw-r--r--pkg/sentry/kernel/thread_group.go179
-rw-r--r--pkg/sentry/kernel/tty.go28
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go14
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go4
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go58
-rw-r--r--pkg/sentry/socket/epsocket/provider.go2
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD9
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go82
-rw-r--r--pkg/sentry/strace/BUILD7
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go33
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go6
-rw-r--r--pkg/sentry/unimpl/BUILD7
-rw-r--r--pkg/sentry/vfs/file_description.go7
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go5
-rw-r--r--pkg/tcpip/link/channel/channel.go3
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go20
-rw-r--r--pkg/tcpip/link/loopback/loopback.go3
-rw-r--r--pkg/tcpip/link/muxed/injectable.go7
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go3
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go3
-rw-r--r--pkg/tcpip/link/waitable/waitable.go3
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go3
-rw-r--r--pkg/tcpip/network/arp/arp.go16
-rw-r--r--pkg/tcpip/network/arp/arp_test.go5
-rw-r--r--pkg/tcpip/network/ip_test.go13
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go44
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go9
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go15
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go24
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go34
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go5
-rw-r--r--pkg/tcpip/ports/ports.go114
-rw-r--r--pkg/tcpip/ports/ports_test.go113
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go5
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go5
-rw-r--r--pkg/tcpip/stack/BUILD4
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go49
-rw-r--r--pkg/tcpip/stack/nic.go80
-rw-r--r--pkg/tcpip/stack/registration.go45
-rw-r--r--pkg/tcpip/stack/stack.go137
-rw-r--r--pkg/tcpip/stack/stack_test.go175
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go227
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go352
-rw-r--r--pkg/tcpip/stack/transport_test.go45
-rw-r--r--pkg/tcpip/tcpip.go36
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go35
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go27
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go30
-rw-r--r--pkg/tcpip/transport/raw/protocol.go9
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go200
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go22
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go258
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go37
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go74
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go12
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go125
87 files changed, 2398 insertions, 1040 deletions
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index afa8f7659..c7f549428 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -8,6 +8,9 @@ go_library(
srcs = ["fd.go"],
importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/unet",
+ ],
)
go_test(
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 83bcfe220..7691b477b 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -22,6 +22,8 @@ import (
"runtime"
"sync/atomic"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/unet"
)
// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
@@ -185,6 +187,12 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
return New(f), nil
}
+// DialUnix connects to a Unix Domain Socket and return the file descriptor.
+func DialUnix(path string) (*FD, error) {
+ socket, err := unet.Connect(path, false)
+ return New(socket.FD()), err
+}
+
// Close closes the file descriptor contained in the FD.
//
// Close is safe to call multiple times, but will return an error after the
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index 842788179..dd6ca6d39 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -1,6 +1,7 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -22,6 +23,12 @@ proto_library(
visibility = ["//:sandbox"],
)
+cc_proto_library(
+ name = "metric_cc_proto",
+ visibility = ["//:sandbox"],
+ deps = [":metric_proto"],
+)
+
go_proto_library(
name = "metric_go_proto",
importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto",
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 123f54e29..2412aa5e1 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -92,6 +92,10 @@ type Client struct {
// version 0 implies 9P2000.L.
version uint32
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
// sendRecv is the transport function.
//
// This is determined dynamically based on whether or not the server
@@ -104,17 +108,15 @@ type Client struct {
// channelsMu protects channels.
channelsMu sync.Mutex
- // channelsWg is a wait group for active clients.
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
channelsWg sync.WaitGroup
- // channels are the set of initialized IPCs channels.
+ // channels is the set of all initialized channels.
channels []*channel
- // inuse is set when the channels are actually in use.
- //
- // This is a fixed-size slice, and the entries will be nil when the
- // corresponding channel is available.
- inuse []*channel
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
// -- below corresponds to sendRecvLegacy --
@@ -135,7 +137,7 @@ type Client struct {
// NewClient creates a new client. It performs a Tversion exchange with
// the server to assert that messageSize is ok to use.
//
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
if messageSize <= msgRegistry.largestFixedSize {
@@ -214,13 +216,6 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
if len(c.channels) >= 1 {
// At least one channel created.
c.sendRecv = c.sendRecvChannel
-
- // If we are using channels for communication, then we must poll
- // for shutdown events on the main socket. If the socket happens
- // to shutdown, then we will close the channels as well. This is
- // necessary because channels can hang forever if the server dies
- // while we're expecting a response.
- go c.watch(socket) // S/R-SAFE: not relevant.
} else {
// Channel setup failed; fallback.
c.sendRecv = c.sendRecvLegacy
@@ -230,13 +225,20 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.sendRecv = c.sendRecvLegacy
}
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
return c, nil
}
-// watch watches the given socket and calls Close on hang up events.
+// watch watches the given socket and releases resources on hangup events.
//
// This is intended to be called as a goroutine.
func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
events := []unix.PollFd{
unix.PollFd{
Fd: int32(socket.FD()),
@@ -244,19 +246,49 @@ func (c *Client) watch(socket *unet.Socket) {
},
}
+ // Wait for a shutdown event.
for {
- // Wait for a shutdown event.
n, err := unix.Ppoll(events, nil, nil)
- if n == 0 || err == syscall.EAGAIN {
+ if err == syscall.EINTR || err == syscall.EAGAIN {
continue
}
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
break
}
- // Close everything down: this will kick all active clients off any
- // pending requests. Note that Close must be safe to call concurrently,
- // and multiple times (see Close below).
- c.Close()
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
}
// openChannel attempts to open a client channel.
@@ -315,7 +347,7 @@ func (c *Client) openChannel(id int) error {
c.channelsMu.Lock()
defer c.channelsMu.Unlock()
c.channels = append(c.channels, res)
- c.inuse = append(c.inuse, nil)
+ c.availableChannels = append(c.availableChannels, res)
return nil
}
@@ -449,23 +481,16 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
// sendRecvChannel uses channels to send a message.
func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
c.channelsMu.Lock()
- if len(c.channels) == 0 {
- // No channel available.
+ if len(c.availableChannels) == 0 {
c.channelsMu.Unlock()
return c.sendRecvLegacy(t, r)
}
-
- // Find the last used channel.
- //
- // Note that we must add one to the wait group while holding the
- // channel mutex, in order for the Wait operation to be race-free
- // below. The Wait operation shuts down all in use channels and
- // waits for them to return, but must do so holding the mutex.
- idx := len(c.channels) - 1
- ch := c.channels[idx]
- c.channels = c.channels[:idx]
- c.inuse[idx] = ch
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
c.channelsWg.Add(1)
c.channelsMu.Unlock()
@@ -473,8 +498,12 @@ func (c *Client) sendRecvChannel(t message, r message) error {
if !ch.connected {
ch.connected = true
if err := ch.data.Connect(); err != nil {
- // The channel is unusable, so don't return it.
- ch.Close()
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
c.channelsWg.Done()
return err
}
@@ -482,24 +511,17 @@ func (c *Client) sendRecvChannel(t message, r message) error {
// Send the message.
err := ch.sendRecv(c, t, r)
- if err != nil {
- // On shutdown, we'll see ENOENT. This is a normal situation, and
- // we shouldn't generate a spurious warning message in that case.
- log.Debugf("error calling sendRecvChannel: %v", err)
- }
- c.channelsWg.Done()
- // Return the channel.
- //
- // Note that we check the channel from the inuse slice here. This
- // prevents a race where Close is called, which clears inuse, and
- // means that we will not actually return the closed channel.
+ // Release the channel.
c.channelsMu.Lock()
- if c.inuse[idx] != nil {
- c.channels = append(c.channels, ch)
- c.inuse[idx] = nil
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
}
c.channelsMu.Unlock()
+ c.channelsWg.Done()
return err
}
@@ -510,44 +532,9 @@ func (c *Client) Version() uint32 {
}
// Close closes the underlying socket and channels.
-//
-// Because Close may be called asynchronously from watch, it must be
-// safe to call concurrently and multiple times.
-func (c *Client) Close() error {
- c.channelsMu.Lock()
- defer c.channelsMu.Unlock()
-
- // Close all inactive channels.
- for _, ch := range c.channels {
- ch.Shutdown()
- ch.Close()
- }
- // Close all active channels.
- for _, ch := range c.inuse {
- if ch != nil {
- log.Debugf("shutting down active channel@%p...", ch)
- ch.Shutdown()
- }
- }
-
- // Wait for active users.
- c.channelsWg.Wait()
-
- // Close all previously active channels.
- for i, ch := range c.inuse {
- if ch != nil {
- ch.Close()
-
- // Clear the inuse entry here so that it will not be returned
- // to the channel slice, which is cleared below. See the
- // comment at the end of sendRecvChannel.
- c.inuse[i] = nil
- }
- }
- c.channels = nil // Prevent use again.
-
- // Close the main socket. Note that operation is safe to be called
- // multiple times, unlikely the channel Close operations above, which
- // we are careful to ensure aren't called twice.
- return c.socket.Close()
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
}
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 1d34181e0..28707c0ca 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -77,7 +77,7 @@ go_library(
go_test(
name = "client_test",
- size = "small",
+ size = "medium",
srcs = ["client_test.go"],
embed = [":p9test"],
deps = [
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index fe649c2e8..8bbdb2488 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) {
}
}
}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 9d74638bb..4d3271b37 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator {
// Finish completes all checks and shuts down the server.
func (h *Harness) Finish() {
- h.clientSocket.Close()
+ h.clientSocket.Shutdown()
h.wg.Wait()
h.mockCtrl.Finish()
}
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index aebb54959..7cdf4ecc3 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -60,6 +60,7 @@ type channel struct {
// -- client only --
connected bool
+ active bool
// -- server only --
client *fd.FD
@@ -197,10 +198,18 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
return nil, &ErrBadResponse{Got: t, Want: r.Type()}
}
- // Is there a payload? Set to the latter portion.
+ // Is there a payload? Copy from the latter portion.
if payloader, ok := r.(payloader); ok {
fs := payloader.FixedSize()
- payloader.SetPayload(ch.buf.data[fs:])
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
ch.buf.data = ch.buf.data[:fs]
}
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
index 0a3d92854..be328db12 100644
--- a/pkg/seccomp/seccomp_unsafe.go
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -35,7 +35,7 @@ type sockFprog struct {
//go:nosplit
func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
// PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0, 0); errno != 0 {
return errno
}
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 7aace2d7b..c71cff9f3 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -42,6 +43,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "registers_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":registers_proto"],
+)
+
go_proto_library(
name = "registers_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto",
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 2526412a4..90331e3b2 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -43,12 +43,15 @@ type TTYFileOperations struct {
// fgProcessGroup is the foreground process group that is currently
// connected to this TTY.
fgProcessGroup *kernel.ProcessGroup
+
+ termios linux.KernelTermios
}
// newTTYFile returns a new fs.File that wraps a TTY FD.
func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{
fileOperations: fileOperations{iops: iops},
+ termios: linux.DefaultSlaveTermios,
})
}
@@ -97,9 +100,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme
t.mu.Lock()
defer t.mu.Unlock()
- // Are we allowed to do the write?
- if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
- return 0, err
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
}
return t.fileOperations.Write(ctx, file, src, offset)
}
@@ -144,6 +150,9 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO
return 0, err
}
err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
return 0, err
case linux.TIOCGPGRP:
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index d799de748..25811f668 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -25,6 +25,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 1d128532b..2f639c823 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -129,6 +129,9 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode {
// Release implements fs.InodeOperations.Release.
func (d *dirInodeOperations) Release(ctx context.Context) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
d.master.DecRef()
if len(d.slaves) != 0 {
panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d))
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index 92ec1ca18..19b7557d5 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -172,6 +172,19 @@ func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io userme
return 0, mf.t.ld.windowSize(ctx, io, args)
case linux.TIOCSWINSZ:
return 0, mf.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
@@ -185,8 +198,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
linux.TCSETS,
linux.TCSETSW,
linux.TCSETSF,
- linux.TIOCGPGRP,
- linux.TIOCSPGRP,
linux.TIOCGWINSZ,
linux.TIOCSWINSZ,
linux.TIOCSETD,
@@ -200,8 +211,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
linux.TIOCEXCL,
linux.TIOCNXCL,
linux.TIOCGEXCL,
- linux.TIOCNOTTY,
- linux.TIOCSCTTY,
linux.TIOCGSID,
linux.TIOCGETD,
linux.TIOCVHANGUP,
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index e30266404..944c4ada1 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -152,9 +152,16 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem
case linux.TIOCSCTTY:
// Make the given terminal the controlling terminal of the
// calling process.
- // TODO(b/129283598): Implement once we have support for job
- // control.
- return 0, nil
+ return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index b7cecb2ed..ff8138820 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -17,7 +17,10 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
// Terminal is a pseudoterminal.
@@ -26,23 +29,100 @@ import (
type Terminal struct {
refs.AtomicRefCount
- // n is the terminal index.
+ // n is the terminal index. It is immutable.
n uint32
- // d is the containing directory.
+ // d is the containing directory. It is immutable.
d *dirInodeOperations
- // ld is the line discipline of the terminal.
+ // ld is the line discipline of the terminal. It is immutable.
ld *lineDiscipline
+
+ // masterKTTY contains the controlling process of the master end of
+ // this terminal. This field is immutable.
+ masterKTTY *kernel.TTY
+
+ // slaveKTTY contains the controlling process of the slave end of this
+ // terminal. This field is immutable.
+ slaveKTTY *kernel.TTY
}
func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal {
termios := linux.DefaultSlaveTermios
t := Terminal{
- d: d,
- n: n,
- ld: newLineDiscipline(termios),
+ d: d,
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{},
+ slaveKTTY: &kernel.TTY{},
}
t.EnableLeakCheck("tty.Terminal")
return &t
}
+
+// setControllingTTY makes tm the controlling terminal of the calling thread
+// group.
+func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int())
+}
+
+// releaseControllingTTY removes tm as the controlling terminal of the calling
+// thread group.
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("releaseControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster))
+}
+
+// foregroundProcessGroup gets the process group ID of tm's foreground process.
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("foregroundProcessGroup must be called from a task context")
+ }
+
+ ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster))
+ if err != nil {
+ return 0, err
+ }
+
+ // Write it out to *arg.
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// foregroundProcessGroup sets tm's foreground process.
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setForegroundProcessGroup must be called from a task context")
+ }
+
+ // Read in the process group ID.
+ var pgid int32
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid))
+ return uintptr(ret), err
+}
+
+func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
+ if isMaster {
+ return tm.masterKTTY
+ }
+ return tm.slaveKTTY
+}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index b51f3e18d..0b471d121 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
}
if !cb.Handle(vfs.Dirent{
- Name: child.diskDirent.FileName(),
- Type: fs.ToDirentType(childType),
- Ino: uint64(child.diskDirent.Inode()),
- Off: fd.off,
+ Name: child.diskDirent.FileName(),
+ Type: fs.ToDirentType(childType),
+ Ino: uint64(child.diskDirent.Inode()),
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 63cf7aeaf..1aa2bd6a4 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) {
// Ignore the inode number and offset of dirents because those are likely to
// change as the underlying image changes.
cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
- return p.String() == "Ino" || p.String() == "Off"
+ return p.String() == "Ino" || p.String() == "NextOff"
}, cmp.Ignore())
if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" {
t.Errorf("dirents mismatch (-want +got):\n%s", diff)
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go
index c52dc781c..c620227c9 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/memfs/directory.go
@@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 0 {
if !cb.Handle(vfs.Dirent{
- Name: ".",
- Type: linux.DT_DIR,
- Ino: vfsd.Impl().(*dentry).inode.ino,
- Off: 0,
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: vfsd.Impl().(*dentry).inode.ino,
+ NextOff: 1,
}) {
return nil
}
@@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
if !cb.Handle(vfs.Dirent{
- Name: "..",
- Type: parentInode.direntType(),
- Ino: parentInode.ino,
- Off: 1,
+ Name: "..",
+ Type: parentInode.direntType(),
+ Ino: parentInode.ino,
+ NextOff: 2,
}) {
return nil
}
@@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Skip other directoryFD iterators.
if child.inode != nil {
if !cb.Handle(vfs.Dirent{
- Name: child.vfsd.Name(),
- Type: child.inode.direntType(),
- Ino: child.inode.ino,
- Off: fd.off,
+ Name: child.vfsd.Name(),
+ Type: child.inode.direntType(),
+ Ino: child.inode.ino,
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD
index d4a420e60..359468ccc 100644
--- a/pkg/sentry/hostcpu/BUILD
+++ b/pkg/sentry/hostcpu/BUILD
@@ -7,6 +7,7 @@ go_library(
name = "hostcpu",
srcs = [
"getcpu_amd64.s",
+ "getcpu_arm64.s",
"hostcpu.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu",
diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s
new file mode 100644
index 000000000..caf9abb89
--- /dev/null
+++ b/pkg/sentry/hostcpu/getcpu_arm64.s
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall for
+// the lack of an optimazed way of getting the current CPU number on arm64.
+
+// func GetCPU() (cpu uint32)
+TEXT ·GetCPU(SB), NOSPLIT, $0-4
+ MOVW ZR, cpu+0(FP)
+ MOVD $cpu+0(FP), R0
+ MOVD $0x0, R1 // unused
+ MOVD $0x0, R2 // unused
+ MOVD $0xA8, R8 // SYS_GETCPU
+ SVC
+ RET
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e964a991b..aba2414d4 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,5 +1,6 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -84,6 +85,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "uncaught_signal_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":uncaught_signal_proto"],
+)
+
go_proto_library(
name = "uncaught_signal_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto",
@@ -145,6 +152,7 @@ go_library(
"threads.go",
"timekeeper.go",
"timekeeper_state.go",
+ "tty.go",
"uts_namespace.go",
"vdso.go",
"version.go",
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index ebcfaa619..d7a7d1169 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -24,6 +25,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "memory_events_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":memory_events_proto"],
+)
+
go_proto_library(
name = "memory_events_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto",
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 81fcd8258..047b5214d 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -47,6 +47,11 @@ type Session struct {
// The id is immutable.
id SessionID
+ // foreground is the foreground process group.
+ //
+ // This is protected by TaskSet.mu.
+ foreground *ProcessGroup
+
// ProcessGroups is a list of process groups in this Session. This is
// protected by TaskSet.mu.
processGroups processGroupList
@@ -260,12 +265,14 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
func (tg *ThreadGroup) CreateSession() error {
tg.pidns.owner.mu.Lock()
defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
return tg.createSession()
}
// createSession creates a new session for a threadgroup.
//
-// Precondition: callers must hold TaskSet.mu for writing.
+// Precondition: callers must hold TaskSet.mu and the signal mutex for writing.
func (tg *ThreadGroup) createSession() error {
// Get the ID for this thread in the current namespace.
id := tg.pidns.tgids[tg]
@@ -321,8 +328,14 @@ func (tg *ThreadGroup) createSession() error {
childTG.processGroup.incRefWithParent(pg)
childTG.processGroup.decRefWithParent(oldParentPG)
})
- tg.processGroup.decRefWithParent(oldParentPG)
+ // If tg.processGroup is an orphan, decRefWithParent will lock
+ // the signal mutex of each thread group in tg.processGroup.
+ // However, tg's signal mutex may already be locked at this
+ // point. We change tg's process group before calling
+ // decRefWithParent to avoid locking tg's signal mutex twice.
+ oldPG := tg.processGroup
tg.processGroup = pg
+ oldPG.decRefWithParent(oldParentPG)
} else {
// The current process group may be nil only in the case of an
// unparented thread group (i.e. the init process). This would
@@ -346,6 +359,9 @@ func (tg *ThreadGroup) createSession() error {
ns.processGroups[ProcessGroupID(local)] = pg
}
+ // Disconnect from the controlling terminal.
+ tg.tty = nil
+
return nil
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index d60cd62c7..ae6fc4025 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -172,9 +172,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
if parentPG := tg.parentPG(); parentPG == nil {
tg.createSession()
} else {
- // Inherit the process group.
+ // Inherit the process group and terminal.
parentPG.incRefWithParent(parentPG)
tg.processGroup = parentPG
+ tg.tty = t.parent.tg.tty
}
}
tg.tasks.PushBack(t)
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 2a97e3e8e..0eef24bfb 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -19,10 +19,13 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// A ThreadGroup is a logical grouping of tasks that has widespread
@@ -245,6 +248,12 @@ type ThreadGroup struct {
//
// mounts is immutable.
mounts *fs.MountNamespace
+
+ // tty is the thread group's controlling terminal. If nil, there is no
+ // controlling terminal.
+ //
+ // tty is protected by the signal mutex.
+ tty *TTY
}
// newThreadGroup returns a new, empty thread group in PID namespace ns. The
@@ -324,6 +333,176 @@ func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) {
}
}
+// SetControllingTTY sets tty as the controlling terminal of tg.
+func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ // We might be asked to set the controlling terminal of multiple
+ // processes, so we lock both the TaskSet and SignalHandlers.
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // "The calling process must be a session leader and not have a
+ // controlling terminal already." - tty_ioctl(4)
+ if tg.processGroup.session.leader != tg || tg.tty != nil {
+ return syserror.EINVAL
+ }
+
+ // "If this terminal is already the controlling terminal of a different
+ // session group, then the ioctl fails with EPERM, unless the caller
+ // has the CAP_SYS_ADMIN capability and arg equals 1, in which case the
+ // terminal is stolen, and all processes that had it as controlling
+ // terminal lose it." - tty_ioctl(4)
+ if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session {
+ if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 {
+ return syserror.EPERM
+ }
+ // Steal the TTY away. Unlike TIOCNOTTY, don't send signals.
+ for othertg := range tg.pidns.owner.Root.tgids {
+ // This won't deadlock by locking tg.signalHandlers
+ // because at this point:
+ // - We only lock signalHandlers if it's in the same
+ // session as the tty's controlling thread group.
+ // - We know that the calling thread group is not in
+ // the same session as the tty's controlling thread
+ // group.
+ if othertg.processGroup.session == tty.tg.processGroup.session {
+ othertg.signalHandlers.mu.Lock()
+ othertg.tty = nil
+ othertg.signalHandlers.mu.Unlock()
+ }
+ }
+ }
+
+ // Set the controlling terminal and foreground process group.
+ tg.tty = tty
+ tg.processGroup.session.foreground = tg.processGroup
+ // Set this as the controlling process of the terminal.
+ tty.tg = tg
+
+ return nil
+}
+
+// ReleaseControllingTTY gives up tty as the controlling tty of tg.
+func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ // We might be asked to set the controlling terminal of multiple
+ // processes, so we lock both the TaskSet and SignalHandlers.
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+
+ // Just below, we may re-lock signalHandlers in order to send signals.
+ // Thus we can't defer Unlock here.
+ tg.signalHandlers.mu.Lock()
+
+ if tg.tty == nil || tg.tty != tty {
+ tg.signalHandlers.mu.Unlock()
+ return syserror.ENOTTY
+ }
+
+ // "If the process was session leader, then send SIGHUP and SIGCONT to
+ // the foreground process group and all processes in the current
+ // session lose their controlling terminal." - tty_ioctl(4)
+ // Remove tty as the controlling tty for each process in the session,
+ // then send them SIGHUP and SIGCONT.
+
+ // If we're not the session leader, we don't have to do much.
+ if tty.tg != tg {
+ tg.tty = nil
+ tg.signalHandlers.mu.Unlock()
+ return nil
+ }
+
+ tg.signalHandlers.mu.Unlock()
+
+ // We're the session leader. SIGHUP and SIGCONT the foreground process
+ // group and remove all controlling terminals in the session.
+ var lastErr error
+ for othertg := range tg.pidns.owner.Root.tgids {
+ if othertg.processGroup.session == tg.processGroup.session {
+ othertg.signalHandlers.mu.Lock()
+ othertg.tty = nil
+ if othertg.processGroup == tg.processGroup.session.foreground {
+ if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil {
+ lastErr = err
+ }
+ if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil {
+ lastErr = err
+ }
+ }
+ othertg.signalHandlers.mu.Unlock()
+ }
+ }
+
+ return lastErr
+}
+
+// ForegroundProcessGroup returns the process group ID of the foreground
+// process group.
+func (tg *ThreadGroup) ForegroundProcessGroup(tty *TTY) (int32, error) {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // "When fd does not refer to the controlling terminal of the calling
+ // process, -1 is returned" - tcgetpgrp(3)
+ if tg.tty != tty {
+ return -1, syserror.ENOTTY
+ }
+
+ return int32(tg.processGroup.session.foreground.id), nil
+}
+
+// SetForegroundProcessGroup sets the foreground process group of tty to pgid.
+func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) (int32, error) {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // TODO(b/129283598): "If tcsetpgrp() is called by a member of a
+ // background process group in its session, and the calling process is
+ // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all
+ // members of this background process group."
+
+ // tty must be the controlling terminal.
+ if tg.tty != tty {
+ return -1, syserror.ENOTTY
+ }
+
+ // pgid must be positive.
+ if pgid < 0 {
+ return -1, syserror.EINVAL
+ }
+
+ // pg must not be empty. Empty process groups are removed from their
+ // pid namespaces.
+ pg, ok := tg.pidns.processGroups[pgid]
+ if !ok {
+ return -1, syserror.ESRCH
+ }
+
+ // pg must be part of this process's session.
+ if tg.processGroup.session != pg.session {
+ return -1, syserror.EPERM
+ }
+
+ tg.processGroup.session.foreground.id = pgid
+ return 0, nil
+}
+
// itimerRealListener implements ktime.Listener for ITIMER_REAL expirations.
//
// +stateify savable
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
new file mode 100644
index 000000000..34f84487a
--- /dev/null
+++ b/pkg/sentry/kernel/tty.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "sync"
+
+// TTY defines the relationship between a thread group and its controlling
+// terminal.
+//
+// +stateify savable
+type TTY struct {
+ mu sync.Mutex `state:"nosave"`
+
+ // tg is protected by mu.
+ tg *ThreadGroup
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 4f8f9c5d9..9f0ecfbe4 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -267,7 +267,7 @@ func (s *subprocess) newThread() *thread {
// attach attaches to the thread.
func (t *thread) attach() {
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("unable to attach: %v", errno))
}
@@ -417,7 +417,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
for {
// Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -435,7 +435,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
// Complete the actual system call.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -526,17 +526,17 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
for {
// Start running until the next system call.
if isSingleStepping(regs) {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU_SINGLESTEP,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
} else {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index f09b0b3d0..c075b5f91 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -53,7 +53,7 @@ func probeSeccomp() bool {
for {
// Attempt an emulation.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -266,7 +266,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// Enable cpuid-faulting; this may fail on older kernels or hardware,
// so we just disregard the result. Host CPUID will be enabled.
- syscall.RawSyscall(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0)
+ syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
// Call the stub; should not return.
stubCall(stubStart, ppid)
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 3e05e40fe..5812085fa 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -209,6 +209,10 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
@@ -415,13 +419,13 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// WriteTo implements fs.FileOperations.WriteTo.
func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
- defer s.readMu.Unlock()
// Copy as much data as possible.
done := int64(0)
for count > 0 {
// This may return a blocking error.
if err := s.fetchReadView(); err != nil {
+ s.readMu.Unlock()
return done, err.ToError()
}
@@ -434,16 +438,18 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
// supported by any Linux system calls, but the
// expectation is that now a caller will call read to
// actually remove these bytes from the socket.
- return done, nil
+ break
}
// Drop that part of the view.
s.readView.TrimFront(n)
if err != nil {
+ s.readMu.Unlock()
return done, err
}
}
+ s.readMu.Unlock()
return done, nil
}
@@ -549,7 +555,11 @@ func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
// ReadFrom implements fs.FileOperations.ReadFrom.
func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
f := &readerPayload{ctx: ctx, r: r, count: count}
- n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ // Reads may be destructive but should be very fast,
+ // so we can't release the lock while copying data.
+ Atomic: true,
+ })
if err == tcpip.ErrWouldBlock {
return 0, syserror.ErrWouldBlock
}
@@ -561,9 +571,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
}
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
- // Reads may be destructive but should be very fast,
- // so we can't release the lock while copying data.
- Atomic: true,
+ Atomic: true, // See above.
})
}
if err == tcpip.ErrWouldBlock {
@@ -883,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -899,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -934,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return int32(v), nil
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if len(v) == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return append([]byte(v), 0), nil
+
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1271,7 +1292,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
case linux.SO_RCVBUF:
if len(optVal) < sizeOfInt32 {
@@ -1279,7 +1300,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
@@ -1297,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2313,9 +2341,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCOUTQ:
- var v tcpip.SendQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 421f93dc4..0a9dfa6c3 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
- return 0, true, syserr.ErrPermissionDenied
+ return 0, true, syserr.ErrNotPermitted
}
switch protocol {
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 5061dcbde..3a6baa308 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -49,6 +50,14 @@ proto_library(
],
)
+cc_proto_library(
+ name = "syscall_rpc_cc_proto",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [":syscall_rpc_proto"],
+)
+
go_proto_library(
name = "syscall_rpc_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2b0ad6395..1867b3a5c 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,6 +175,10 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
@@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrQueueSizeNotSupported
}
return v, nil
- default:
- return -1, tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
- case *tcpip.SendQueueSizeOption:
+ case tcpip.SendQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ v := e.connected.SendQueuedSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
- }
- *o = qs
- return nil
-
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- return nil
+ return int(v), nil
- case *tcpip.SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ v := e.connected.SendMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- *o = qs
- return nil
+ return int(v), nil
- case *tcpip.ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
e.Lock()
if e.receiver == nil {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ v := e.receiver.RecvMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
}
- *o = qs
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 445d25010..7d7b42eba 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -44,6 +45,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "strace_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":strace_proto"],
+)
+
go_proto_library(
name = "strace_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 18d24ab61..61acd0abd 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -232,7 +232,7 @@ var AMD64 = &kernel.SyscallTable{
184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil),
185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
186: syscalls.Supported("gettid", Gettid),
- 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341)
+ 187: syscalls.Supported("readahead", Readahead),
188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 3ab54271c..cd31e0649 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -72,6 +72,39 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
}
+// Readahead implements readahead(2).
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL; if the underlying file type does not support readahead,
+ // then Linux will return EINVAL to indicate as much. In the future, we
+ // may extend this function to actually support readahead hints.
+ return 0, nil, syserror.EINVAL
+}
+
// Pread64 implements linux syscall pread64(2).
func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 3bac4d90d..b5a72ce63 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.ENOTSOCK
}
- if optLen <= 0 {
+ if optLen < 0 {
return 0, nil, syserror.EINVAL
}
if optLen > maxOptLen {
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index 271ace08e..748e8dd8d 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EINVAL
}
- name, err := t.CopyInString(nameAddr, int(size))
- if err != nil {
+ name := make([]byte, size)
+ if _, err := t.CopyInBytes(nameAddr, name); err != nil {
return 0, nil, err
}
- utsns.SetHostName(name)
+ utsns.SetHostName(string(name))
return 0, nil, nil
}
diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD
index b69603da3..fc7614fff 100644
--- a/pkg/sentry/unimpl/BUILD
+++ b/pkg/sentry/unimpl/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -10,6 +11,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "unimplemented_syscall_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":unimplemented_syscall_proto"],
+)
+
go_proto_library(
name = "unimplemented_syscall_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto",
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 86bde7fb3..7eb2b2821 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -199,8 +199,11 @@ type Dirent struct {
// Ino is the inode number.
Ino uint64
- // Off is this Dirent's offset.
- Off int64
+ // NextOff is the offset of the *next* Dirent in the directory; that is,
+ // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will
+ // cause the next call to FileDescription.IterDirents() to yield the next
+ // Dirent. (The offset of the first Dirent in a directory is always 0.)
+ NextOff int64
}
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 672f026b2..8ced960bb 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) {
func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+ })
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
return nil, err
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index eec430d0a..18adb2085 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -133,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index adcf21371..584db710e 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,6 +41,7 @@ package fdbased
import (
"fmt"
+ "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -81,6 +82,7 @@ const (
PacketMMap
)
+// An endpoint implements the link-layer using a message-oriented file descriptor.
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -114,6 +116,9 @@ type endpoint struct {
// gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
// disabled.
gsoMaxSize uint32
+
+ // wg keeps track of running goroutines.
+ wg sync.WaitGroup
}
// Options specify the details about the fd-based endpoint to be created.
@@ -164,7 +169,8 @@ type Options struct {
// New creates a new fd-based endpoint.
//
// Makes fd non-blocking, but does not take ownership of fd, which must remain
-// open for the lifetime of the returned endpoint.
+// open for the lifetime of the returned endpoint (until after the endpoint has
+// stopped being using and Wait returns).
func New(opts *Options) (stack.LinkEndpoint, error) {
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
@@ -290,7 +296,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
for i := range e.inboundDispatchers {
- go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above.
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
}
}
@@ -320,6 +330,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
+// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
+// reading from its FD.
+func (e *endpoint) Wait() {
+ e.wg.Wait()
+}
+
// virtioNetHdr is declared in linux/virtio_net.h.
type virtioNetHdr struct {
flags uint8
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index e121ea1a5..b36629d2c 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 3ed7b98d1..7c946101d 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -104,6 +104,13 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *
return endpoint.WriteRawPacket(dest, packet)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (m *InjectableEndpoint) Wait() {
+ for _, ep := range m.routes {
+ ep.Wait()
+ }
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index ba387af73..9e71d4edf 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -132,7 +132,8 @@ func (e *endpoint) Close() {
}
}
-// Wait waits until all workers have stopped after a Close() call.
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
func (e *endpoint) Wait() {
e.completed.Wait()
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index e7b6d7912..e401dce44 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 408cc62f7..5a1791cb5 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -120,3 +120,6 @@ func (e *Endpoint) WaitWrite() {
func (e *Endpoint) WaitDispatch() {
e.dispatchGate.Close()
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 1031438b1..ae23c96b7 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -70,6 +70,9 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P
return nil
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*countedEndpoint) Wait() {}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index fd6395fc1..26cf1c528 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -16,9 +16,9 @@
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
//
-// To use it in the networking stack, pass arp.ProtocolName as one of the
-// network protocols when calling stack.New. Then add an "arp" address to
-// every NIC on the stack that should respond to ARP requests. That is:
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
//
// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
// // handle err
@@ -33,9 +33,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ARP protocol name.
- ProtocolName = "arp"
-
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
@@ -200,8 +197,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 387fca96e..88b57ec03 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -44,7 +44,10 @@ type testContext struct {
}
func newTestContext(t *testing.T) *testContext {
- s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ })
const defaultMTU = 65536
ep := channel.New(256, defaultMTU, stackLinkAddr)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4b3bd74fa..a9741622e 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
@@ -169,7 +172,10 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen
}
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -182,7 +188,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b7a06f525..b7b07a6c1 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -14,9 +14,9 @@
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4
@@ -32,9 +32,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv4 protocol name.
- ProtocolName = "ipv4"
-
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -53,6 +50,7 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
+ protocol *protocol
}
// NewEndpoint creates a new ipv4 endpoint.
@@ -64,6 +62,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
}
return e, nil
@@ -204,7 +203,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1)
}
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -267,7 +266,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
if payload.Size() > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
}
ip.SetID(uint16(id))
}
@@ -325,14 +324,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {}
-type protocol struct{}
-
-// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+type protocol struct {
+ ids []uint32
+ hashIV uint32
}
// Number returns the ipv4 protocol number.
@@ -378,7 +372,7 @@ func calculateMTU(mtu uint32) uint32 {
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
t := r.LocalAddress
a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = r.RemoteAddress
@@ -386,22 +380,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-var (
- ids []uint32
- hashIV uint32
-)
-
-func init() {
- ids = make([]uint32, buckets)
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
r := hash.RandN32(1 + buckets)
for i := range ids {
ids[i] = r[i]
}
- hashIV = r[buckets]
+ hashIV := r[buckets]
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+ return &protocol{ids: ids, hashIV: hashIV}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index ae827ca27..b6641ccc3 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -33,7 +33,10 @@ import (
)
func TestExcludeBroadcast(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
const defaultMTU = 65536
ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
@@ -238,7 +241,9 @@ type context struct {
func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
// Make the packet and write it.
- s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
s.CreateNIC(1, ep)
const (
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 653d984e9..01f5a17ec 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -81,7 +81,10 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li
}
func TestICMPCounts(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
{
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -205,8 +208,14 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
- s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
- s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
+ s0: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ s1: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
}
const defaultMTU = 65536
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 331a8bdaa..7de6a4546 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -14,9 +14,9 @@
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv6
@@ -28,9 +28,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv6 protocol name.
- ProtocolName = "ipv6"
-
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -160,14 +157,6 @@ func (*endpoint) Close() {}
type protocol struct{}
-// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
-}
-
// Number returns the ipv6 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
@@ -221,8 +210,7 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 57bcd5455..78c674c2c 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -124,17 +124,20 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
// UDP packets destined to the IPv6 link-local all-nodes multicast address.
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
tests := []struct {
- name string
- protocolName string
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.ProtocolName6, testReceiveICMP},
- {"UDP", udp.ProtocolName, testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -152,19 +155,22 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
tests := []struct {
- name string
- protocolName string
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.ProtocolName6, testReceiveICMP},
- {"UDP", udp.ProtocolName, testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
}
snmc := header.SolicitedNodeAddr(addr2)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -237,7 +243,9 @@ func TestAddIpv6Address(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New([]string{ProtocolName}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 571915d3f..e30791fe3 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -31,7 +31,10 @@ import (
func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 315780c0c..40e202717 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -47,43 +47,76 @@ type portNode struct {
refs int
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]portNode
-
-// isAvailable checks whether an IP address is available to bind to.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
- if addr == anyIPAddress {
- if len(b) == 0 {
- return true
- }
+// deviceNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type deviceNode map[tcpip.NICID]portNode
+
+// isAvailable checks whether binding is possible by device. If not binding to a
+// device, check against all portNodes. If binding to a specific device, check
+// against the unspecified device and the provided device.
+func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+ if bindToDevice == 0 {
+ // Trying to binding all devices.
if !reuse {
+ // Can't bind because the (addr,port) is already bound.
return false
}
- for _, n := range b {
- if !n.reuse {
+ for _, p := range d {
+ if !p.reuse {
+ // Can't bind because the (addr,port) was previously bound without reuse.
return false
}
}
return true
}
- // If all addresses for this portDescriptor are already bound, no
- // address is available.
- if n, ok := b[anyIPAddress]; ok {
- if !reuse {
+ if p, ok := d[0]; ok {
+ if !reuse || !p.reuse {
return false
}
- if !n.reuse {
+ }
+
+ if p, ok := d[bindToDevice]; ok {
+ if !reuse || !p.reuse {
return false
}
}
- if n, ok := b[addr]; ok {
- if !reuse {
+ return true
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]deviceNode
+
+// isAvailable checks whether an IP address is available to bind to. If the
+// address is the "any" address, check all other addresses. Otherwise, just
+// check against the "any" address and the provided address.
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+ if addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no conflicts
+ // with all addresses.
+ for _, d := range b {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+ return true
+ }
+
+ // Check that there is no conflict with the "any" address.
+ if d, ok := b[anyIPAddress]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+
+ // Check that this is no conflict with the provided address.
+ if d, ok := b[addr]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
return false
}
- return n.reuse
}
+
return true
}
@@ -116,17 +149,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse) {
+ if !addrs.isAvailable(addr, reuse, bindToDevice) {
return false
}
}
@@ -138,14 +171,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -153,13 +186,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
return false
}
@@ -171,11 +204,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
m = make(bindAddresses)
s.allocatedPorts[desc] = m
}
- if n, ok := m[addr]; ok {
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ if n, ok := d[bindToDevice]; ok {
n.refs++
- m[addr] = n
+ d[bindToDevice] = n
} else {
- m[addr] = portNode{reuse: reuse, refs: 1}
+ d[bindToDevice] = portNode{reuse: reuse, refs: 1}
}
}
@@ -184,22 +222,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
- n, ok := m[addr]
+ d, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n, ok := d[bindToDevice]
if !ok {
continue
}
n.refs--
+ d[bindToDevice] = n
if n.refs == 0 {
+ delete(d, bindToDevice)
+ }
+ if len(d) == 0 {
delete(m, addr)
- } else {
- m[addr] = n
}
if len(m) == 0 {
delete(s.allocatedPorts, desc)
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 689401661..a67e283f1 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -34,6 +34,7 @@ type portReserveTestAction struct {
want *tcpip.Error
reuse bool
release bool
+ device tcpip.NICID
}
func TestPortReservation(t *testing.T) {
@@ -100,6 +101,112 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: anyIPAddress, release: true},
{port: 24, ip: anyIPAddress, reuse: false, want: nil},
},
+ }, {
+ tname: "bind twice with device fails",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 3, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 1, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 2, want: nil},
+ },
+ }, {
+ tname: "bind to device and then without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ },
+ }, {
+ tname: "binding with reuse and device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "mixing reuse and not reuse by binding to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: nil},
+ },
+ }, {
+ tname: "can't bind to 0 after mixing reuse and not reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind and release",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+
+ // Release the bind to device 0 and try again.
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuse once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "release an unreserved device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ // The below don't exist.
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ // Release all.
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ },
},
} {
t.Run(test.tname, func(t *testing.T) {
@@ -108,12 +215,12 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index f12189580..2239c1e66 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -126,7 +126,10 @@ func main() {
// Create the stack with ipv4 and tcp protocols, then add a tun-based
// NIC and ipv4 address.
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 329941775..bca73cbb1 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -111,7 +111,10 @@ func main() {
// Create the stack with ip and tcp protocols, then add a tun-based
// NIC and address.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 28c49e8ff..3842f1f7d 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -54,6 +54,7 @@ go_test(
size = "small",
srcs = [
"stack_test.go",
+ "transport_demuxer_test.go",
"transport_test.go",
],
deps = [
@@ -64,6 +65,9 @@ go_test(
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
index f8156be47..3a20839da 100644
--- a/pkg/tcpip/stack/icmp_rate_limit.go
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -15,8 +15,6 @@
package stack
import (
- "sync"
-
"golang.org/x/time/rate"
)
@@ -33,54 +31,11 @@ const (
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
type ICMPRateLimiter struct {
- mu sync.RWMutex
- l *rate.Limiter
+ *rate.Limiter
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
// at which ICMP messages are generated by the stack.
func NewICMPRateLimiter() *ICMPRateLimiter {
- return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)}
-}
-
-// Allow returns true if we are allowed to send at least 1 message at the
-// moment.
-func (i *ICMPRateLimiter) Allow() bool {
- i.mu.RLock()
- allow := i.l.Allow()
- i.mu.RUnlock()
- return allow
-}
-
-// Limit returns the maximum number of ICMP messages that can be sent in one
-// second.
-func (i *ICMPRateLimiter) Limit() rate.Limit {
- i.mu.RLock()
- defer i.mu.RUnlock()
- return i.l.Limit()
-}
-
-// SetLimit sets the maximum number of ICMP messages that can be sent in one
-// second.
-func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) {
- i.mu.RLock()
- defer i.mu.RUnlock()
- i.l.SetLimit(newLimit)
-}
-
-// Burst returns how many ICMP messages can be sent at any single instant.
-func (i *ICMPRateLimiter) Burst() int {
- i.mu.RLock()
- defer i.mu.RUnlock()
- return i.l.Burst()
-}
-
-// SetBurst sets the maximum number of ICMP messages allowed at any single
-// instant.
-//
-// NOTE: Changing Burst causes the underlying rate limiter to be recreated.
-func (i *ICMPRateLimiter) SetBurst(burst int) {
- i.mu.Lock()
- i.l = rate.NewLimiter(i.l.Limit(), burst)
- i.mu.Unlock()
+ return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a719058b4..f6106f762 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -34,8 +34,6 @@ type NIC struct {
linkEP LinkEndpoint
loopback bool
- demux *transportDemuxer
-
mu sync.RWMutex
spoofing bool
promiscuous bool
@@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
name: name,
linkEP: ep,
loopback: loopback,
- demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -148,37 +145,6 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- var r *referencedNetworkEndpoint
-
- // Check for a primary endpoint.
- if list, ok := n.primary[protocol]; ok {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
- if ref.getKind() == permanent && ref.tryIncRef() {
- r = ref
- break
- }
- }
-
- }
-
- if r == nil {
- return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress
- }
-
- addressWithPrefix := tcpip.AddressWithPrefix{
- Address: r.ep.ID().LocalAddress,
- PrefixLen: r.ep.PrefixLen(),
- }
- r.decRef()
-
- return addressWithPrefix, nil
-}
-
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
@@ -398,10 +364,12 @@ func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return err
}
-// Addresses returns the addresses associated with this NIC.
-func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+// AllAddresses returns all addresses (primary and non-primary) associated with
+// this NIC.
+func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
+
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
// Don't include expired or tempory endpoints to avoid confusion and
@@ -421,6 +389,34 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
return addrs
}
+// PrimaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var addrs []tcpip.ProtocolAddress
+ for proto, list := range n.primary {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ // Don't include expired or tempory endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
+
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ },
+ })
+ }
+ }
+ return addrs
+}
+
// AddAddressRange adds a range of addresses to n, so that it starts accepting
// packets targeted at the given addresses and network protocol. The range is
// given by a subnet address, and all addresses contained in the subnet are
@@ -708,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
- }
+ n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
if len(vv.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
@@ -724,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
- return
- }
if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
@@ -768,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
- return
- }
- if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) {
return
}
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 88a698b18..80101d4bb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -295,6 +295,15 @@ type LinkEndpoint interface {
// IsAttached returns whether a NetworkDispatcher is attached to the
// endpoint.
IsAttached() bool
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // For now, requesting that an endpoint's worker goroutine(s) stop is
+ // implementation specific.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -357,14 +366,6 @@ type LinkAddressCache interface {
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// TransportProtocolFactory functions are used by the stack to instantiate
-// transport protocols.
-type TransportProtocolFactory func() TransportProtocol
-
-// NetworkProtocolFactory provides methods to be used by the stack to
-// instantiate network protocols.
-type NetworkProtocolFactory func() NetworkProtocol
-
// UnassociatedEndpointFactory produces endpoints for writing packets not
// associated with a particular transport protocol. Such endpoints can be used
// to write arbitrary packets that include the IP header.
@@ -372,34 +373,6 @@ type UnassociatedEndpointFactory interface {
NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
-var (
- transportProtocols = make(map[string]TransportProtocolFactory)
- networkProtocols = make(map[string]NetworkProtocolFactory)
-
- unassociatedFactory UnassociatedEndpointFactory
-)
-
-// RegisterTransportProtocolFactory registers a new transport protocol factory
-// with the stack so that it becomes available to users of the stack. This
-// function is intended to be called by init() functions of the protocols.
-func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
- transportProtocols[name] = p
-}
-
-// RegisterNetworkProtocolFactory registers a new network protocol factory with
-// the stack so that it becomes available to users of the stack. This function
-// is intended to be called by init() functions of the protocols.
-func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
- networkProtocols[name] = p
-}
-
-// RegisterUnassociatedFactory registers a factory to produce endpoints not
-// associated with any particular transport protocol. This function is intended
-// to be called by init() functions of the protocols.
-func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
- unassociatedFactory = f
-}
-
// GSOType is the type of GSO segments.
//
// +stateify savable
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 1fe21b68e..6a8079823 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -17,11 +17,6 @@
//
// For consumers, the only function of interest is New(), everything else is
// provided by the tcpip/public package.
-//
-// For protocol implementers, RegisterTransportProtocolFactory() and
-// RegisterNetworkProtocolFactory() are used to register protocol factories with
-// the stack, which will then be used to instantiate protocol objects when
-// consumers interact with the stack.
package stack
import (
@@ -351,6 +346,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // unassociatedFactory creates unassociated endpoints. If nil, raw
+ // endpoints are disabled. It is set during Stack creation and is
+ // immutable.
unassociatedFactory UnassociatedEndpointFactory
demux *transportDemuxer
@@ -359,10 +357,6 @@ type Stack struct {
linkAddrCache *linkAddrCache
- // raw indicates whether raw sockets may be created. It is set during
- // Stack creation and is immutable.
- raw bool
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
forwarding bool
@@ -398,6 +392,12 @@ type Stack struct {
// Options contains optional Stack configuration.
type Options struct {
+ // NetworkProtocols lists the network protocols to enable.
+ NetworkProtocols []NetworkProtocol
+
+ // TransportProtocols lists the transport protocols to enable.
+ TransportProtocols []TransportProtocol
+
// Clock is an optional clock source used for timestampping packets.
//
// If no Clock is specified, the clock source will be time.Now.
@@ -411,8 +411,9 @@ type Options struct {
// stack (false).
HandleLocal bool
- // Raw indicates whether raw sockets may be created.
- Raw bool
+ // UnassociatedFactory produces unassociated endpoints raw endpoints.
+ // Raw endpoints are enabled only if this is non-nil.
+ UnassociatedFactory UnassociatedEndpointFactory
}
// New allocates a new networking stack with only the requested networking and
@@ -422,7 +423,7 @@ type Options struct {
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
-func New(network []string, transport []string, opts Options) *Stack {
+func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
clock = &tcpip.StdClock{}
@@ -438,17 +439,11 @@ func New(network []string, transport []string, opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- raw: opts.Raw,
icmpRateLimiter: NewICMPRateLimiter(),
}
// Add specified network protocols.
- for _, name := range network {
- netProtoFactory, ok := networkProtocols[name]
- if !ok {
- continue
- }
- netProto := netProtoFactory()
+ for _, netProto := range opts.NetworkProtocols {
s.networkProtocols[netProto.Number()] = netProto
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
@@ -456,18 +451,14 @@ func New(network []string, transport []string, opts Options) *Stack {
}
// Add specified transport protocols.
- for _, name := range transport {
- transProtoFactory, ok := transportProtocols[name]
- if !ok {
- continue
- }
- transProto := transProtoFactory()
+ for _, transProto := range opts.TransportProtocols {
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
}
- s.unassociatedFactory = unassociatedFactory
+ // Add the factory for unassociated endpoints, if present.
+ s.unassociatedFactory = opts.UnassociatedFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -602,7 +593,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if !s.raw {
+ if s.unassociatedFactory == nil {
return nil, tcpip.ErrNotPermitted
}
@@ -738,7 +729,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.linkEP.LinkAddress(),
- ProtocolAddresses: nic.Addresses(),
+ ProtocolAddresses: nic.PrimaryAddresses(),
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
@@ -845,19 +836,37 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
-// GetMainNICAddress returns the first primary address (and the subnet that
-// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
-// address if no primary addresses exist. Returns an error if the NIC doesn't
-// exist or has no endpoints.
+// AllAddresses returns a map of NICIDs to their protocol addresses (primary
+// and non-primary).
+func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
+ for id, nic := range s.nics {
+ nics[id] = nic.AllAddresses()
+ }
+ return nics
+}
+
+// GetMainNICAddress returns the first primary address and prefix for the given
+// NIC and protocol. Returns an error if the NIC doesn't exist and an empty
+// value if the NIC doesn't have a primary address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic, ok := s.nics[id]; ok {
- return nic.getMainNICAddress(protocol)
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ for _, a := range nic.PrimaryAddresses() {
+ if a.Protocol == protocol {
+ return a.AddressWithPrefix, nil
+ }
+ }
+ return tcpip.AddressWithPrefix{}, nil
}
func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
@@ -1024,73 +1033,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- }
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerRawEndpoint(netProto, transProto, ep)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProto, transProto, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
- }
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
// RegisterRestoredEndpoint records e as an endpoint that has been restored on
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 0c26c9911..d2dede8a9 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -222,11 +222,17 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeNetFactory() stack.NetworkProtocol {
+ return &fakeNetworkProtocol{}
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
ep := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -370,7 +376,9 @@ func TestNetworkSend(t *testing.T) {
// address: 1. The route table sends all packets through the only
// existing nic.
ep := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("NewNIC failed:", err)
}
@@ -395,7 +403,9 @@ func TestNetworkSendMultiRoute(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
@@ -476,7 +486,9 @@ func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
@@ -554,7 +566,9 @@ func TestAddressRemoval(t *testing.T) {
localAddr := tcpip.Address([]byte{localAddrByte})
remoteAddr := tcpip.Address("\x02")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -599,7 +613,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
localAddr := tcpip.Address([]byte{localAddrByte})
remoteAddr := tcpip.Address("\x02")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -688,7 +704,9 @@ func TestEndpointExpiration(t *testing.T) {
for _, promiscuous := range []bool{true, false} {
for _, spoofing := range []bool{true, false} {
t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
@@ -844,7 +862,9 @@ func TestEndpointExpiration(t *testing.T) {
}
func TestPromiscuousMode(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -894,7 +914,9 @@ func TestSpoofingWithAddress(t *testing.T) {
nonExistentLocalAddr := tcpip.Address("\x02")
dstAddr := tcpip.Address("\x03")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -930,10 +952,10 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != nonExistentLocalAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
}
// Sending a packet works.
testSendTo(t, s, dstAddr, ep, nil)
@@ -945,10 +967,10 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != localAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
}
// Sending a packet using the route works.
testSend(t, r, ep, nil)
@@ -958,7 +980,9 @@ func TestSpoofingNoAddress(t *testing.T) {
nonExistentLocalAddr := tcpip.Address("\x01")
dstAddr := tcpip.Address("\x02")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -992,10 +1016,10 @@ func TestSpoofingNoAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != nonExistentLocalAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
}
// Sending a packet works.
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
@@ -1003,7 +1027,9 @@ func TestSpoofingNoAddress(t *testing.T) {
}
func TestBroadcastNeedsNoRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1074,7 +1100,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
{"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
} {
t.Run(tc.name, func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1130,7 +1158,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
// Add a range of addresses, then check that a packet is delivered.
func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1196,7 +1226,9 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub
// existent.
func TestCheckLocalAddressForSubnet(t *testing.T) {
const nicID tcpip.NICID = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -1234,7 +1266,9 @@ func TestCheckLocalAddressForSubnet(t *testing.T) {
// Set a range of addresses, then send a packet to a destination outside the
// range and then check it doesn't get delivered.
func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1266,7 +1300,10 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
}
func TestNetworkOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{},
+ })
// Try an unsupported network protocol.
if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -1319,7 +1356,9 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S
}
func TestAddresRangeAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1360,7 +1399,9 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
for never := 0; never < 3; never++ {
t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1400,20 +1441,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
if len(primaryAddrAdded) == 0 {
- // No primary addresses present, expect an error.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress)
+ // No primary addresses present.
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
}
} else {
- // At least one primary address was added, expect a valid
- // address and prefixLen.
- gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok {
- t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded)
+ // At least one primary address was added, verify the returned
+ // address is in the list of primary addresses we added.
+ if _, ok := primaryAddrAdded[gotAddr]; !ok {
+ t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
}
}
})
@@ -1425,7 +1466,9 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
}
func TestGetMainNICAddressAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1452,19 +1495,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
t.Fatal("GetMainNICAddress failed:", err)
- } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix)
+ }
+ if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
+ t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- // Check that we get an error after removal.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress)
+ // Check that we get no address after removal.
+ gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
})
}
@@ -1479,8 +1528,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address {
}
func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
+ t.Helper()
+
if len(gotAddresses) != len(expectedAddresses) {
- t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses))
+ t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses))
}
sort.Slice(gotAddresses, func(i, j int) bool {
@@ -1500,7 +1551,9 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
func TestAddAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1519,13 +1572,15 @@ func TestAddAddress(t *testing.T) {
})
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1551,13 +1606,15 @@ func TestAddProtocolAddress(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1580,13 +1637,15 @@ func TestAddAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1615,12 +1674,14 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestNICStats(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed: ", err)
@@ -1666,7 +1727,9 @@ func TestNICStats(t *testing.T) {
func TestNICForwarding(t *testing.T) {
// Create a stack with the fake network protocol, two NICs, each with
// an address.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
s.SetForwarding(true)
ep1 := channel.New(10, defaultMTU, "")
@@ -1714,9 +1777,3 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
-
-func init() {
- stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
- })
-}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cf8a6d129..8c768c299 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -35,25 +35,109 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]TransportEndpoint
+ endpoints map[TransportEndpointID]*endpointsByNic
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
}
+type endpointsByNic struct {
+ mu sync.RWMutex
+ endpoints map[tcpip.NICID]*multiPortEndpoint
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ }
+
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints bound to the right device.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ mpep.handlePacketAll(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[n.ID()]
+ if !ok {
+ mpep, ok = epsByNic.endpoints[0]
+ }
+ if !ok {
+ return
+ }
+
+ // TODO(eyalsoha): Why don't we look at id to see if this packet needs to
+ // broadcast like we are doing with handlePacket above?
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
+}
+
+// registerEndpoint returns true if it succeeds. It fails and returns
+// false if ep already has an element with the same key.
+func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+
+ if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
+ // There was already a bind.
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ }
+
+ // This is a new binding.
+ multiPortEp := &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.reuse = reusePort
+ epsByNic.endpoints[bindToDevice] = multiPortEp
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+}
+
+// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
+func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+ multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ if !ok {
+ return false
+ }
+ if multiPortEp.unregisterEndpoint(t) {
+ delete(epsByNic.endpoints, bindToDevice)
+ }
+ return len(epsByNic.endpoints) == 0
+}
+
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- e, ok := eps.endpoints[id]
+ epsByNic, ok := eps.endpoints[id]
if !ok {
return
}
- if multiPortEp, ok := e.(*multiPortEndpoint); ok {
- if !multiPortEp.unregisterEndpoint(ep) {
- return
- }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
}
delete(eps.endpoints, id)
}
@@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
}
}
@@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
return err
}
}
@@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
-// the same pair of address and port.
+// the same pair of address and port. endpointsArr always has at least one
+// element.
type multiPortEndpoint struct {
mu sync.RWMutex
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
- // seed is a random secret for a jenkins hash.
- seed uint32
+ // reuse indicates if more than one endpoint is allowed.
+ reuse bool
}
// reciprocalScale scales a value into range [0, n).
@@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
- ep.mu.RLock()
- defer ep.mu.RUnlock()
+func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
+ if len(mpep.endpointsArr) == 1 {
+ return mpep.endpointsArr[0]
+ }
payload := []byte{
byte(id.LocalPort),
@@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
byte(id.RemotePort >> 8),
}
- h := jenkins.Sum32(ep.seed)
+ h := jenkins.Sum32(seed)
h.Write(payload)
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
- return ep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
+ return mpep.endpointsArr[idx]
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast or multicast datagram, deliver the datagram to all
- // endpoints managed by ep.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy.
- if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
- break
- }
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ ep.mu.RLock()
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy except for
+ // the final one.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
}
- } else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
+ ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
- ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
-}
-
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
+// list. The list might be empty already.
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- // A new endpoint is added into endpointsArr and its index there is
- // saved in endpointsMap. This will allows to remove endpoint from
- // the array fast.
+ if len(ep.endpointsArr) > 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if !ep.reuse || !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ // A new endpoint is added into endpointsArr and its index there is saved in
+ // endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+ return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
@@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
return true
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
+ // TODO(eyalsoha): Why?
reusePort = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return nil
+ return tcpip.ErrUnknownProtocol
}
eps.mu.Lock()
defer eps.mu.Unlock()
- var multiPortEp *multiPortEndpoint
- if _, ok := eps.endpoints[id]; ok {
- if !reusePort {
- return tcpip.ErrPortInUse
- }
- multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
- if !ok {
- return tcpip.ErrPortInUse
- }
+ if epsByNic, ok := eps.endpoints[id]; ok {
+ // There was already a binding.
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
- if reusePort {
- if multiPortEp == nil {
- multiPortEp = &multiPortEndpoint{}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.seed = rand.Uint32()
- eps.endpoints[id] = multiPortEp
- }
-
- multiPortEp.singleRegisterEndpoint(ep)
-
- return nil
+ // This is a new binding.
+ epsByNic := &endpointsByNic{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
}
- eps.endpoints[id] = ep
+ eps.endpoints[id] = epsByNic
- return nil
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep)
+ eps.unregisterEndpoint(id, ep, bindToDevice)
}
}
}
@@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a broadcast, then find all matching transport endpoints.
// Otherwise, try to find a single matching transport endpoint.
- destEps := make([]TransportEndpoint, 0, 1)
+ destEps := make([]*endpointsByNic, 0, 1)
eps.mu.RLock()
if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
@@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.handlePacket(r, id, vv)
}
return true
@@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
}
// Deliver the packet.
- ep.HandleControlPacket(id, typ, extra, vv)
+ ep.handleControlPacket(n, id, typ, extra, vv)
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
return ep
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
new file mode 100644
index 000000000..210233dc0
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -0,0 +1,352 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testPort = 4096
+)
+
+type testContext struct {
+ t *testing.T
+ linkEPs map[string]*channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createV6Endpoint(v6only bool) {
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var v tcpip.V6OnlyOption
+ if v6only {
+ v = 1
+ }
+ if err := c.ep.SetSockOpt(v); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+}
+
+// newDualTestContextMultiNic creates the testing context and also linkEpNames
+// named NICs.
+func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ linkEPs := make(map[string]*channel.Endpoint)
+ for i, linkEpName := range linkEpNames {
+ channelEP := channel.New(256, mtu, "")
+ nicid := tcpip.NICID(i + 1)
+ if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ linkEPs[linkEpName] = channelEP
+
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %v", err)
+ }
+
+ if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %v", err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEPs: linkEPs,
+ }
+}
+
+type headers struct {
+ srcPort uint16
+ dstPort uint16
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testV6Addr,
+ DstAddr: stackV6Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+}
+
+func TestTransportDemuxerRegister(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ want *tcpip.Error
+ }{
+ {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"success", ipv4.ProtocolNumber, nil},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+// TestReuseBindToDevice injects varied packets on input devices and checks that
+// the distribution of packets received matches expectations.
+func TestDistribution(t *testing.T) {
+ type endpointSockopts struct {
+ reuse int
+ bindToDevice string
+ }
+ for _, test := range []struct {
+ name string
+ // endpoints will received the inject packets.
+ endpoints []endpointSockopts
+ // wantedDistribution is the wanted ratio of packets received on each
+ // endpoint for each NIC on which packets are injected.
+ wantedDistributions map[string][]float64
+ }{
+ {
+ "BindPortReuse",
+ // 5 endpoints that all have reuse set.
+ []endpointSockopts{
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed evenly.
+ "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ },
+ },
+ {
+ "BindToDevice",
+ // 3 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{0, "dev0"},
+ endpointSockopts{0, "dev1"},
+ endpointSockopts{0, "dev2"},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 go only to the endpoint bound to dev0.
+ "dev0": []float64{1, 0, 0},
+ // Injected packets on dev1 go only to the endpoint bound to dev1.
+ "dev1": []float64{0, 1, 0},
+ // Injected packets on dev2 go only to the endpoint bound to dev2.
+ "dev2": []float64{0, 0, 1},
+ },
+ },
+ {
+ "ReuseAndBindToDevice",
+ // 6 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed among endpoints bound to
+ // dev0.
+ "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ // Injected packets on dev1 get distributed among endpoints bound to
+ // dev1 or unbound.
+ "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ // Injected packets on dev999 go only to the unbound.
+ "dev999": []float64{0, 0, 0, 0, 0, 1},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ for device, wantedDistribution := range test.wantedDistributions {
+ t.Run(device, func(t *testing.T) {
+ var devices []string
+ for d := range test.wantedDistributions {
+ devices = append(devices, d)
+ }
+ c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ eps := make(map[tcpip.Endpoint]int)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i, endpoint := range test.endpoints {
+ // Try to receive the data.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ var err *tcpip.Error
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps[ep] = i
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(ep)
+
+ defer ep.Close()
+ reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
+ if err := ep.SetSockOpt(reusePortOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ }
+ bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
+ if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
+ }
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ c.sendV6Packet(payload,
+ &headers{
+ srcPort: testPort + port,
+ dstPort: stackPort},
+ device)
+
+ var addr tcpip.FullAddress
+ ep := <-pollChannel
+ _, _, err := ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled by the same
+ // socket.
+ if want, got := ports[port], ep; want != got {
+ t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
+ }
+ }
+ }
+
+ // Check that a packet distribution is as expected.
+ for ep, i := range eps {
+ wantedRatio := wantedDistribution[i]
+ wantedRecv := wantedRatio * float64(npackets)
+ actualRecv := stats[ep]
+ actualRatio := float64(stats[ep]) / float64(npackets)
+ // The deviation is less than 10%.
+ if math.Abs(actualRatio-wantedRatio) > 0.05 {
+ t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ }
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 847d02982..842a16277 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -122,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -163,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false,
+ false, /* reuse */
+ 0, /* bindtoDevice */
); err != nil {
return err
}
@@ -277,9 +283,16 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeTransFactory() stack.TransportProtocol {
+ return &fakeTransportProtocol{}
+}
+
func TestTransportReceive(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -341,7 +354,10 @@ func TestTransportReceive(t *testing.T) {
func TestTransportControlReceive(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -409,7 +425,10 @@ func TestTransportControlReceive(t *testing.T) {
func TestTransportSend(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -452,7 +471,10 @@ func TestTransportSend(t *testing.T) {
}
func TestTransportOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
// Try an unsupported transport protocol.
if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -493,7 +515,10 @@ func TestTransportOptions(t *testing.T) {
}
func TestTransportForwarding(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
s.SetForwarding(true)
// TODO(b/123449044): Change this to a channel NIC.
@@ -571,9 +596,3 @@ func TestTransportForwarding(t *testing.T) {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
-
-func init() {
- stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
- return &fakeTransportProtocol{}
- })
-}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 2534069ab..faaa4a4e3 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -401,6 +401,10 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptInt sets a socket option, for simple cases where a value
+ // has the int type.
+ SetSockOptInt(opt SockOpt, v int) *Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
@@ -446,10 +450,22 @@ type WriteOptions struct {
type SockOpt int
const (
- // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
- // unread bytes in the input buffer should be returned.
+ // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption SockOpt = iota
+ // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the send buffer size option.
+ SendBufferSizeOption
+
+ // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the receive buffer size option.
+ ReceiveBufferSizeOption
+
+ // SendQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the output buffer should be returned.
+ SendQueueSizeOption
+
// TODO(b/137664753): convert all int socket options to be handled via
// GetSockOptInt.
)
@@ -458,18 +474,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
-// buffer size option.
-type SendBufferSizeOption int
-
-// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
-// receive buffer size option.
-type ReceiveBufferSizeOption int
-
-// SendQueueSizeOption is used in GetSockOpt to specify that the number of
-// unread bytes in the output buffer should be returned.
-type SendQueueSizeOption int
-
// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
@@ -491,6 +495,10 @@ type ReuseAddressOption int
// to be bound to an identical socket address.
type ReusePortOption int
+// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
+// should bind only on a specific NIC.
+type BindToDeviceOption string
+
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 3db060384..a3a910d41 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -104,7 +104,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -319,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -331,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -341,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -538,14 +543,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 1eb790932..bfb16f7c3 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -14,10 +14,9 @@
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and
-// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
-// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
-// calling stack.New(). Then endpoints can be created by passing
+// must be added to the project, and activated on the stack by passing
+// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
+// protocols when calling stack.New(). Then endpoints can be created by passing
// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
// when calling Stack.NewEndpoint().
package icmp
@@ -34,15 +33,9 @@ import (
)
const (
- // ProtocolName4 is the string representation of the icmp protocol name.
- ProtocolName4 = "icmp4"
-
// ProtocolNumber4 is the ICMP protocol number.
ProtocolNumber4 = header.ICMPv4ProtocolNumber
- // ProtocolName6 is the string representation of the icmp protocol name.
- ProtocolName6 = "icmp6"
-
// ProtocolNumber6 is the IPv6-ICMP protocol number.
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
@@ -125,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
- })
+// NewProtocol4 returns an ICMPv4 transport protocol.
+func NewProtocol4() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+}
- stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
- })
+// NewProtocol6 returns an ICMPv6 transport protocol.
+func NewProtocol6() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index cf1c5c433..a02731a5d 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -492,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -504,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
ep.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSize
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
@@ -515,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
- ep.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
- ep.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index 783c21e6b..a2512d666 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -20,13 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-type factory struct{}
+// EndpointFactory implements stack.UnassociatedEndpointFactory.
+type EndpointFactory struct{}
// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
-
-func init() {
- stack.RegisterUnassociatedFactory(factory{})
-}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 0802e984e..3ae4a5426 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil {
n.Close()
return nil, err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index dd931f88c..a1cd0d481 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -280,6 +280,9 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -564,11 +567,11 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -625,12 +628,12 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -952,62 +955,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
}
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
-
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = int(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
+// SetSockOptInt sets a socket option.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1071,6 +1021,82 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -1182,6 +1208,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+ case tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ v := e.sndBufSize
+ e.sndBufMu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ v := e.rcvBufSize
+ e.rcvListMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1204,18 +1242,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.sndBufMu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
- e.rcvListMu.Unlock()
- return nil
-
case *tcpip.DelayOption:
*o = 0
if v := atomic.LoadUint32(&e.delay); v != 0 {
@@ -1252,6 +1278,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = ""
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1458,7 +1494,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if e.id.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1472,13 +1508,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if sameAddr && p == e.id.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ // reusePort is false below because connect cannot reuse a port even if
+ // reusePort was set.
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
return false, nil
}
id := e.id
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
e.id = id
return true, nil
@@ -1496,7 +1534,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -1640,7 +1678,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil {
return err
}
@@ -1721,7 +1759,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1731,16 +1769,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.id.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func() {
+ defer func(bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
e.id.LocalPort = 0
e.id.LocalAddress = ""
e.boundNICID = 0
}
- }()
+ }(e.bindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a13b2022..d5d8ab96a 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -14,7 +14,7 @@
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// activated on the stack by passing tcp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing tcp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -34,9 +34,6 @@ import (
)
const (
- // ProtocolName is the string representation of the tcp protocol name.
- ProtocolName = "tcp"
-
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
@@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
- congestionControl: ccReno,
- availableCongestionControl: []string{ccReno, ccCubic},
- }
- })
+// NewProtocol returns a TCP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 272bbcdbd..9fa97528b 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
enableCUBIC(t, c)
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 32bb45224..089826a88 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
@@ -299,7 +299,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
want := stats.TCP.ResetsReceived.Value() + 1
iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -323,7 +323,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
want := stats.TCP.ResetsReceived.Value() + 1
iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -344,14 +344,14 @@ func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -367,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -417,7 +417,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -465,11 +465,71 @@ func TestSimpleReceive(t *testing.T) {
)
}
+func TestConnectBindToDevice(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ device string
+ want tcp.EndpointState
+ }{
+ {"RightDevice", "nic1", tcp.StateEstablished},
+ {"WrongDevice", "nic2", tcp.StateSynSent},
+ {"AnyDevice", "", tcp.StateEstablished},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+ bindToDevice := tcpip.BindToDeviceOption(test.device)
+ c.EP.SetSockOpt(bindToDevice)
+ // Start connection attempt.
+ waitEntry, _ := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ c.GetPacket()
+ if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ })
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -557,8 +617,7 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
// Create a new connection with initial window size of 10.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -631,7 +690,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -700,7 +759,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -785,7 +844,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -804,8 +863,7 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -872,11 +930,9 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.Cleanup()
// Start off with a window size of 10, then shrink it to 5.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
- opt = 5
- if err := c.EP.SetSockOpt(opt); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
t.Fatalf("SetSockOpt failed: %v", err)
}
@@ -976,7 +1032,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1017,7 +1073,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 0, nil)
+ c.CreateConnected(789, 0, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1075,8 +1131,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1110,8 +1165,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1151,7 +1205,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1224,7 +1278,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1293,8 +1347,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the window size such that a window scale of 4 will be used.
const wnd = 65535 * 10
const ws = uint32(4)
- opt := tcpip.ReceiveBufferSizeOption(wnd)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1399,7 +1452,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Prevent the endpoint from processing packets.
test.stop(c.EP)
@@ -1449,7 +1502,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1497,7 +1550,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1579,7 +1632,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -1695,7 +1748,7 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
@@ -1704,7 +1757,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -1727,7 +1780,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1871,7 +1924,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 2
- if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1973,7 +2026,7 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -2010,7 +2063,7 @@ func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -2035,7 +2088,7 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2078,7 +2131,7 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2132,7 +2185,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
@@ -2203,7 +2256,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -2291,7 +2344,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
@@ -2377,7 +2430,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -2509,7 +2562,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
@@ -2559,7 +2612,7 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
@@ -2580,7 +2633,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
vv := c.BuildSegment(nil, &context.Headers{
@@ -2604,7 +2657,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
@@ -2635,7 +2688,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
data := []byte{1, 2, 3}
@@ -2681,7 +2734,7 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -2856,8 +2909,8 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2869,8 +2922,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2880,7 +2933,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
}
func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2926,7 +2982,10 @@ func TestDefaultBufferSizes(t *testing.T) {
}
func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2945,37 +3004,96 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
+
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
+ }
+
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func makeStack() (*stack.Stack, *tcpip.Error) {
- s := stack.New([]string{
- ipv4.ProtocolName,
- ipv6.ProtocolName,
- }, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ ipv6.NewProtocol(),
+ },
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
id := loopback.New()
if testing.Verbose() {
@@ -3231,7 +3349,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -3308,7 +3426,7 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -3482,7 +3600,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 16783e716..ef823e4ae 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -137,7 +137,10 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Allow minimum send/receive buffer sizes to be 1 during tests.
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
@@ -155,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context {
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
+ if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
+ if testing.Verbose() {
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
+ }
+ if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -512,7 +522,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
}
// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
@@ -585,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.Port = tcpHdr.SourcePort()
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -598,11 +604,20 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ if epRcvBuf != -1 {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
c.Connect(iss, rcvWnd, options)
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c1ca22b35..7a635ab8d 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -52,6 +52,7 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6ac7c067a..52f5af777 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -88,6 +88,7 @@ type endpoint struct {
multicastNICID tcpip.NICID
multicastLoop bool
reusePort bool
+ bindToDevice tcpip.NICID
broadcast bool
// shutdownFlags represent the current shutdown state of the endpoint.
@@ -144,8 +145,8 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
}
for _, mem := range e.multicastMemberships {
@@ -389,7 +390,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
@@ -546,6 +552,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.reusePort = v != 0
e.mu.Unlock()
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
@@ -568,7 +589,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
}
+
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -578,18 +612,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -640,6 +662,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = tcpip.BindToDeviceOption("")
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -747,12 +779,12 @@ func (e *endpoint) Disconnect() *tcpip.Error {
} else {
if e.id.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.id = id
e.route.Release()
e.route = stack.Route{}
@@ -829,7 +861,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
}
e.id = id
@@ -892,16 +924,16 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
}
return id, err
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a9edc2c8d..2d0bc5221 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -74,7 +74,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 068d9a272..f5cc932dd 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -14,7 +14,7 @@
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// activated on the stack by passing udp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -30,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the udp protocol name.
- ProtocolName = "udp"
-
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
@@ -182,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
- })
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index c6deab892..5059ca22d 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -17,7 +17,6 @@ package udp_test
import (
"bytes"
"fmt"
- "math"
"math/rand"
"testing"
"time"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -274,7 +274,10 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -473,87 +476,59 @@ func newMinPayload(minSize int) []byte {
return b
}
-func TestBindPortReuse(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- var eps [5]tcpip.Endpoint
- reusePortOpt := tcpip.ReusePortOption(1)
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- pollChannel := make(chan tcpip.Endpoint)
- for i := 0; i < len(eps); i++ {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- var err *tcpip.Error
- eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(eps[i])
-
- defer eps[i].Close()
- if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
}
+ defer ep.Close()
- npackets := 100000
- nports := 10000
- ports := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- c.injectV6Packet(payload, &header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- })
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
- var addr tcpip.FullAddress
- ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- stats[ep]++
- if i < nports {
- ports[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled
- // by the same socket.
- if ports[port] != ep {
- t.Fatalf("Port mismatch")
- }
- }
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
}
- if len(stats) != len(eps) {
- t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
}
- // Check that a packet distribution is fair between sockets.
- for _, c := range stats {
- n := float64(npackets) / float64(len(eps))
- // The deviation is less than 10%.
- if math.Abs(float64(c)-n) > n/10 {
- t.Fatal(c, n)
- }
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
}
}