diff options
Diffstat (limited to 'pkg')
190 files changed, 3473 insertions, 1069 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index ba233b93f..f45934466 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -2,9 +2,11 @@ # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix # when the host OS may not be Linux. +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "linux", @@ -44,6 +46,7 @@ go_library( "sem.go", "shm.go", "signal.go", + "signalfd.go", "socket.go", "splice.go", "tcp.go", diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go new file mode 100644 index 000000000..85fad9956 --- /dev/null +++ b/pkg/abi/linux/signalfd.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +const ( + // SFD_NONBLOCK is a signalfd(2) flag. + SFD_NONBLOCK = 00004000 + + // SFD_CLOEXEC is a signalfd(2) flag. + SFD_CLOEXEC = 02000000 +) + +// SignalfdSiginfo is the siginfo encoding for signalfds. +type SignalfdSiginfo struct { + Signo uint32 + Errno int32 + Code int32 + PID uint32 + UID uint32 + FD int32 + TID uint32 + Band uint32 + Overrun uint32 + TrapNo uint32 + Status int32 + Int int32 + Ptr uint64 + UTime uint64 + STime uint64 + Addr uint64 + AddrLSB uint16 + _ [48]uint8 +} diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 39d253b98..6bc486b62 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 47ab65346..5f59866fa 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD index 09d6c2c1f..543fb54bf 100644 --- a/pkg/binary/BUILD +++ b/pkg/binary/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 0c2dde4f8..51967b811 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index b692aa3b1..8d31e068c 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "bpf", diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index cdec96df1..a0b21d4bd 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index 830e19e07..32422f9e2 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "cpuid", diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 9961baaa9..71f2abc83 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index 785c685a0..afa8f7659 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index e54e7371c..56495cbd9 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index bd1d614b6..5643d5f26 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -18,6 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", + "//third_party/gvsync", ], ) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index d59159912..8390915a2 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error { *ep.dataLen() = w.Len() // Return control to the client. + raceBecomeInactive() if err := ep.futexSwitchToPeer(); err != nil { return err } diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 991018684..386cee42c 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -180,7 +180,11 @@ const ( // Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), // ep.SendRecv(), and ep.SendLast() have never been called. func (ep *Endpoint) Connect() error { - return ep.ctrlConnect() + err := ep.ctrlConnect() + if err == nil { + raceBecomeActive() + } + return err } // RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then @@ -192,6 +196,7 @@ func (ep *Endpoint) RecvFirst() (uint32, error) { if err := ep.ctrlWaitFirst(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -218,9 +223,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { // after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then // they can only shoot themselves in the foot. *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlRoundTrip(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -240,6 +247,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) } *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlWakeLast(); err != nil { return err } diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index 435e4eeae..168a487ec 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -62,6 +62,9 @@ func (c *testConnection) destroy() { } func testSendRecv(t *testing.T, c *testConnection) { + // This shared variable is used to confirm that synchronization between + // flipcall endpoints is visible to the Go race detector. + state := 0 var serverRun sync.WaitGroup serverRun.Add(1) go func() { @@ -71,11 +74,19 @@ func testSendRecv(t *testing.T, c *testConnection) { t.Errorf("server Endpoint.RecvFirst() failed: %v", err) return } + state++ + if state != 2 { + t.Errorf("shared state counter: got %d, wanted 2", state) + } t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") if _, err := c.serverEP.SendRecv(0); err != nil { t.Errorf("server Endpoint.SendRecv() failed: %v", err) return } + state++ + if state != 4 { + t.Errorf("shared state counter: got %d, wanted 4", state) + } t.Logf("server Endpoint got packet 3") }() defer func() { @@ -89,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) { if err := c.clientEP.Connect(); err != nil { t.Fatalf("client Endpoint.Connect() failed: %v", err) } + state++ + if state != 1 { + t.Errorf("shared state counter: got %d, wanted 1", state) + } t.Logf("client Endpoint sending packet 1 and waiting for packet 2") if _, err := c.clientEP.SendRecv(0); err != nil { t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } + state++ + if state != 3 { + t.Errorf("shared state counter: got %d, wanted 3", state) + } t.Logf("client Endpoint got packet 2, sending packet 3") if err := c.clientEP.SendLast(0); err != nil { t.Fatalf("client Endpoint.SendLast() failed: %v", err) diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 73e6eef29..a37952637 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -17,6 +17,8 @@ package flipcall import ( "reflect" "unsafe" + + "gvisor.dev/gvisor/third_party/gvsync" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte { bsReflect.Cap = int(ep.dataCap) return bs } + +// ioSync is a dummy variable used to indicate synchronization to the Go race +// detector. Compare syscall.ioSync. +var ioSync int64 + +func raceBecomeActive() { + if gvsync.RaceEnabled { + gvsync.RaceAcquire((unsafe.Pointer)(&ioSync)) + } +} + +func raceBecomeInactive() { + if gvsync.RaceEnabled { + gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + } +} diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index 11716af81..0c5f50397 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index e6a8dbd02..4b9321711 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD index 8f3defa25..34d2673ef 100644 --- a/pkg/ilist/BUILD +++ b/pkg/ilist/BUILD @@ -1,5 +1,6 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index c8e923a74..a5d980d14 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 12615240c..fc5f5779b 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 3b8a691f4..dd6ca6d39 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -1,5 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -21,6 +23,12 @@ proto_library( visibility = ["//:sandbox"], ) +cc_proto_library( + name = "metric_cc_proto", + visibility = ["//:sandbox"], + deps = [":metric_proto"], +) + go_proto_library( name = "metric_go_proto", importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto", diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index c6737bf97..f32244c69 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], @@ -19,11 +20,14 @@ go_library( "pool.go", "server.go", "transport.go", + "transport_flipcall.go", "version.go", ], importpath = "gvisor.dev/gvisor/pkg/p9", deps = [ "//pkg/fd", + "//pkg/fdchannel", + "//pkg/flipcall", "//pkg/log", "//pkg/unet", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 7dc20aeef..2412aa5e1 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -20,6 +20,8 @@ import ( "sync" "syscall" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -77,6 +79,47 @@ type Client struct { // fidPool is the collection of available fids. fidPool pool + // messageSize is the maximum total size of a message. + messageSize uint32 + + // payloadSize is the maximum payload size of a read or write. + // + // For large reads and writes this means that the read or write is + // broken up into buffer-size/payloadSize requests. + payloadSize uint32 + + // version is the agreed upon version X of 9P2000.L.Google.X. + // version 0 implies 9P2000.L. + version uint32 + + // closedWg is marked as done when the Client.watch() goroutine, which is + // responsible for closing channels and the socket fd, returns. + closedWg sync.WaitGroup + + // sendRecv is the transport function. + // + // This is determined dynamically based on whether or not the server + // supports flipcall channels (preferred as it is faster and more + // efficient, and does not require tags). + sendRecv func(message, message) error + + // -- below corresponds to sendRecvChannel -- + + // channelsMu protects channels. + channelsMu sync.Mutex + + // channelsWg counts the number of channels for which channel.active == + // true. + channelsWg sync.WaitGroup + + // channels is the set of all initialized channels. + channels []*channel + + // availableChannels is a FIFO of inactive channels. + availableChannels []*channel + + // -- below corresponds to sendRecvLegacy -- + // pending is the set of pending messages. pending map[Tag]*response pendingMu sync.Mutex @@ -89,25 +132,12 @@ type Client struct { // Whoever writes to this channel is permitted to call recv. When // finished calling recv, this channel should be emptied. recvr chan bool - - // messageSize is the maximum total size of a message. - messageSize uint32 - - // payloadSize is the maximum payload size of a read or write - // request. For large reads and writes this means that the - // read or write is broken up into buffer-size/payloadSize - // requests. - payloadSize uint32 - - // version is the agreed upon version X of 9P2000.L.Google.X. - // version 0 implies 9P2000.L. - version uint32 } // NewClient creates a new client. It performs a Tversion exchange with // the server to assert that messageSize is ok to use. // -// You should not use the same socket for multiple clients. +// If NewClient succeeds, ownership of socket is transferred to the new Client. func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { // Need at least one byte of payload. if messageSize <= msgRegistry.largestFixedSize { @@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client return nil, ErrBadVersionString } for { + // Always exchange the version using the legacy version of the + // protocol. If the protocol supports flipcall, then we switch + // our sendRecv function to use that functionality. Otherwise, + // we stick to sendRecvLegacy. rversion := Rversion{} - err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion) + err := c.sendRecvLegacy(&Tversion{ + Version: versionString(requested), + MSize: messageSize, + }, &rversion) // The server told us to try again with a lower version. if err == syscall.EAGAIN { @@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client c.version = version break } + + // Can we switch to use the more advanced channels and create + // independent channels for communication? Prefer it if possible. + if versionSupportsFlipcall(c.version) { + // Attempt to initialize IPC-based communication. + for i := 0; i < channelsPerClient; i++ { + if err := c.openChannel(i); err != nil { + log.Warningf("error opening flipcall channel: %v", err) + break // Stop. + } + } + if len(c.channels) >= 1 { + // At least one channel created. + c.sendRecv = c.sendRecvChannel + } else { + // Channel setup failed; fallback. + c.sendRecv = c.sendRecvLegacy + } + } else { + // No channels available: use the legacy mechanism. + c.sendRecv = c.sendRecvLegacy + } + + // Ensure that the socket and channels are closed when the socket is shut + // down. + c.closedWg.Add(1) + go c.watch(socket) // S/R-SAFE: not relevant. + return c, nil } +// watch watches the given socket and releases resources on hangup events. +// +// This is intended to be called as a goroutine. +func (c *Client) watch(socket *unet.Socket) { + defer c.closedWg.Done() + + events := []unix.PollFd{ + unix.PollFd{ + Fd: int32(socket.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == syscall.EINTR || err == syscall.EAGAIN { + continue + } + if err != nil { + log.Warningf("p9.Client.watch(): %v", err) + break + } + if n != 1 { + log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Set availableChannels to nil so that future calls to c.sendRecvChannel() + // don't attempt to activate a channel, and concurrent calls to + // c.sendRecvChannel() don't mark released channels as available. + c.channelsMu.Lock() + c.availableChannels = nil + + // Shut down all active channels. + for _, ch := range c.channels { + if ch.active { + log.Debugf("shutting down active channel@%p...", ch) + ch.Shutdown() + } + } + c.channelsMu.Unlock() + + // Wait for active channels to become inactive. + c.channelsWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.Close() + } + c.channelsMu.Unlock() + + // Close the main socket. + c.socket.Close() +} + +// openChannel attempts to open a client channel. +// +// Note that this function returns naked errors which should not be propagated +// directly to a caller. It is expected that the errors will be logged and a +// fallback path will be used instead. +func (c *Client) openChannel(id int) error { + var ( + rchannel0 Rchannel + rchannel1 Rchannel + res = new(channel) + ) + + // Open the data channel. + if err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 0, + }, &rchannel0); err != nil { + return fmt.Errorf("error handling Tchannel message: %v", err) + } + if rchannel0.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on primary channel") + } + + // We don't need to hold this. + defer rchannel0.FilePayload().Close() + + // Open the channel for file descriptors. + if err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 1, + }, &rchannel1); err != nil { + return err + } + if rchannel1.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on file descriptor channel") + } + + // Construct the endpoints. + res.desc = flipcall.PacketWindowDescriptor{ + FD: rchannel0.FilePayload().FD(), + Offset: int64(rchannel0.Offset), + Length: int(rchannel0.Length), + } + if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil { + rchannel1.FilePayload().Close() + return err + } + + // The fds channel owns the control payload, and it will be closed when + // the channel object is closed. + res.fds.Init(rchannel1.FilePayload().Release()) + + // Save the channel. + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + c.channels = append(c.channels, res) + c.availableChannels = append(c.availableChannels, res) + return nil +} + // handleOne handles a single incoming message. // // This should only be called with the token from recvr. Note that the received @@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error { } } -// sendRecv performs a roundtrip message exchange. +// sendRecvLegacy performs a roundtrip message exchange. // // This is called by internal functions. -func (c *Client) sendRecv(t message, r message) error { +func (c *Client) sendRecvLegacy(t message, r message) error { tag, ok := c.tagPool.Get() if !ok { return ErrOutOfTags @@ -296,12 +479,62 @@ func (c *Client) sendRecv(t message, r message) error { return nil } +// sendRecvChannel uses channels to send a message. +func (c *Client) sendRecvChannel(t message, r message) error { + // Acquire an available channel. + c.channelsMu.Lock() + if len(c.availableChannels) == 0 { + c.channelsMu.Unlock() + return c.sendRecvLegacy(t, r) + } + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + ch.active = true + c.channelsWg.Add(1) + c.channelsMu.Unlock() + + // Ensure that it's connected. + if !ch.connected { + ch.connected = true + if err := ch.data.Connect(); err != nil { + // The channel is unusable, so don't return it to + // c.availableChannels. However, we still have to mark it as + // inactive so c.watch() doesn't wait for it. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + return err + } + } + + // Send the message. + err := ch.sendRecv(c, t, r) + + // Release the channel. + c.channelsMu.Lock() + ch.active = false + // If c.availableChannels is nil, c.watch() has fired and we should not + // mark this channel as available. + if c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.channelsMu.Unlock() + c.channelsWg.Done() + + return err +} + // Version returns the negotiated 9P2000.L.Google version number. func (c *Client) Version() uint32 { return c.version } -// Close closes the underlying socket. -func (c *Client) Close() error { - return c.socket.Close() +// Close closes the underlying socket and channels. +func (c *Client) Close() { + // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already + // been called (by c.watch()). + c.socket.Shutdown() + c.closedWg.Wait() } diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index 87b2dd61e..29a0afadf 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -35,23 +35,23 @@ func TestVersion(t *testing.T) { go s.Handle(serverSocket) // NewClient does a Tversion exchange, so this is our test for success. - c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString()) + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) if err != nil { t.Fatalf("got %v, expected nil", err) } // Check a bogus version string. - if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { + if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { t.Errorf("got %v expected %v", err, syscall.EINVAL) } // Check a bogus version number. - if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { + if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { t.Errorf("got %v expected %v", err, syscall.EINVAL) } // Check a too high version number. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN { + if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN { t.Errorf("got %v expected %v", err, syscall.EAGAIN) } @@ -60,3 +60,45 @@ func TestVersion(t *testing.T) { t.Errorf("got %v expected %v", err, syscall.EINVAL) } } + +func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { + // See above. + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + b.Fatalf("socketpair got err %v expected nil", err) + } + defer clientSocket.Close() + + // See above. + s := NewServer(nil) + go s.Handle(serverSocket) + + // See above. + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) + if err != nil { + b.Fatalf("got %v, expected nil", err) + } + + // Initialize messages. + sendRecv := fn(c) + tversion := &Tversion{ + Version: versionString(highestSupportedVersion), + MSize: DefaultMessageSize, + } + rversion := new(Rversion) + + // Run in a loop. + for i := 0; i < b.N; i++ { + if err := sendRecv(tversion, rversion); err != nil { + b.Fatalf("got unexpected err: %v", err) + } + } +} + +func BenchmarkSendRecvLegacy(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy }) +} + +func BenchmarkSendRecvChannel(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel }) +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 999b4f684..ba9a55d6d 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -305,7 +305,9 @@ func (t *Tlopen) handle(cs *connState) message { ref.opened = true ref.openFlags = t.Flags - return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile} + rlopen := &Rlopen{QID: qid, IoUnit: ioUnit} + rlopen.SetFilePayload(osFile) + return rlopen } func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { @@ -364,7 +366,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { // Replace the FID reference. cs.InsertFID(t.FID, newRef) - return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil + rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}} + rlcreate.SetFilePayload(osFile) + return rlcreate, nil } // handle implements handler.handle. @@ -1287,5 +1291,48 @@ func (t *Tlconnect) handle(cs *connState) message { return newErr(err) } - return &Rlconnect{File: osFile} + rlconnect := &Rlconnect{} + rlconnect.SetFilePayload(osFile) + return rlconnect +} + +// handle implements handler.handle. +func (t *Tchannel) handle(cs *connState) message { + // Ensure that channels are enabled. + if err := cs.initializeChannels(); err != nil { + return newErr(err) + } + + // Lookup the given channel. + ch := cs.lookupChannel(t.ID) + if ch == nil { + return newErr(syscall.ENOSYS) + } + + // Return the payload. Note that we need to duplicate the file + // descriptor for the channel allocator, because sending is a + // destructive operation between sendRecvLegacy (and now the newer + // channel send operations). Same goes for the client FD. + rchannel := &Rchannel{ + Offset: uint64(ch.desc.Offset), + Length: uint64(ch.desc.Length), + } + switch t.Control { + case 0: + // Open the main data channel. + mfd, err := syscall.Dup(int(cs.channelAlloc.FD())) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(mfd)) + case 1: + cfd, err := syscall.Dup(ch.client.FD()) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(cfd)) + default: + return newErr(syscall.EINVAL) + } + return rchannel } diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index fd9eb1c5d..ffdd7e8c6 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -64,6 +64,21 @@ type filer interface { SetFilePayload(*fd.FD) } +// filePayload embeds a File object. +type filePayload struct { + File *fd.FD +} + +// FilePayload returns the file payload. +func (f *filePayload) FilePayload() *fd.FD { + return f.File +} + +// SetFilePayload sets the received file. +func (f *filePayload) SetFilePayload(file *fd.FD) { + f.File = file +} + // Tversion is a version request. type Tversion struct { // MSize is the message size to use. @@ -524,10 +539,7 @@ type Rlopen struct { // IoUnit is the recommended I/O unit. IoUnit uint32 - // File may be attached via the socket. - // - // This is an extension specific to this package. - File *fd.FD + filePayload } // Decode implements encoder.Decode. @@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType { return MsgRlopen } -// FilePayload returns the file payload. -func (r *Rlopen) FilePayload() *fd.FD { - return r.File -} - -// SetFilePayload sets the received file. -func (r *Rlopen) SetFilePayload(file *fd.FD) { - r.File = file -} - // String implements fmt.Stringer. func (r *Rlopen) String() string { return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File) @@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string { // Rlconnect is a connect response. type Rlconnect struct { - // File is a host socket. - File *fd.FD + filePayload } // Decode implements encoder.Decode. @@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType { return MsgRlconnect } -// FilePayload returns the file payload. -func (r *Rlconnect) FilePayload() *fd.FD { - return r.File +// String implements fmt.Stringer. +func (r *Rlconnect) String() string { + return fmt.Sprintf("Rlconnect{File: %v}", r.File) } -// SetFilePayload sets the received file. -func (r *Rlconnect) SetFilePayload(file *fd.FD) { - r.File = file +// Tchannel creates a new channel. +type Tchannel struct { + // ID is the channel ID. + ID uint32 + + // Control is 0 if the Rchannel response should provide the flipcall + // component of the channel, and 1 if the Rchannel response should + // provide the fdchannel component of the channel. + Control uint32 +} + +// Decode implements encoder.Decode. +func (t *Tchannel) Decode(b *buffer) { + t.ID = b.Read32() + t.Control = b.Read32() +} + +// Encode implements encoder.Encode. +func (t *Tchannel) Encode(b *buffer) { + b.Write32(t.ID) + b.Write32(t.Control) +} + +// Type implements message.Type. +func (*Tchannel) Type() MsgType { + return MsgTchannel } // String implements fmt.Stringer. -func (r *Rlconnect) String() string { - return fmt.Sprintf("Rlconnect{File: %v}", r.File) +func (t *Tchannel) String() string { + return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control) +} + +// Rchannel is the channel response. +type Rchannel struct { + Offset uint64 + Length uint64 + filePayload +} + +// Decode implements encoder.Decode. +func (r *Rchannel) Decode(b *buffer) { + r.Offset = b.Read64() + r.Length = b.Read64() +} + +// Encode implements encoder.Encode. +func (r *Rchannel) Encode(b *buffer) { + b.Write64(r.Offset) + b.Write64(r.Length) +} + +// Type implements message.Type. +func (*Rchannel) Type() MsgType { + return MsgRchannel +} + +// String implements fmt.Stringer. +func (r *Rchannel) String() string { + return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) } const maxCacheSize = 3 @@ -2356,4 +2409,6 @@ func init() { msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} }) msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) + msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) + msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index e12831dbd..25530adca 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -378,6 +378,8 @@ const ( MsgRlconnect = 137 MsgTallocate = 138 MsgRallocate = 139 + MsgTchannel = 250 + MsgRchannel = 251 ) // QIDType represents the file type for QIDs. diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 6e939a49a..28707c0ca 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") package(licenses = ["notice"]) @@ -77,7 +77,7 @@ go_library( go_test( name = "client_test", - size = "small", + size = "medium", srcs = ["client_test.go"], embed = [":p9test"], deps = [ diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index fe649c2e8..8bbdb2488 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) { } } } + +func TestReadWriteConcurrent(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + const ( + instances = 10 + iterations = 10000 + dataSize = 1024 + ) + var ( + dataSets [instances][dataSize]byte + backends [instances]*Mock + files [instances]p9.File + ) + + // Walk to the file normally. + for i := 0; i < instances; i++ { + _, backends[i], files[i] = walkHelper(h, "file", root) + defer files[i].Close() + } + + // Open the files. + for i := 0; i < instances; i++ { + backends[i].EXPECT().Open(p9.ReadWrite) + if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil { + t.Fatalf("open got %v, wanted nil", err) + } + } + + // Initialize random data for each instance. + for i := 0; i < instances; i++ { + if _, err := rand.Read(dataSets[i][:]); err != nil { + t.Fatalf("error initializing dataSet#%d, got %v", i, err) + } + } + + // Define our random read/write mechanism. + randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) { + // Prepare the backend. + backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if n := copy(p, data); n != len(data) { + // Note that we have to assert the result here, as the Return statement + // below cannot be dynamic: it will be bound before this call is made. + h.t.Errorf("wanted length %d, got %d", len(data), n) + } + }).Return(len(data), nil) + + // Execute the read. + if n, err := f.ReadAt(test, 0); n != len(test) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err) + return // No sense doing check below. + } + if !bytes.Equal(test, data) { + t.Errorf("data integrity failed during read") // Not as expected. + } + } + randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) { + // Prepare the backend. + backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if !bytes.Equal(p, data) { + h.t.Errorf("data integrity failed during write") // Not as expected. + } + }).Return(len(data), nil) + + // Execute the write. + if n, err := f.WriteAt(data, 0); n != len(data) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err) + } + } + randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) { + test := make([]byte, len(data)) + for i := 0; i < n; i++ { + if rand.Intn(2) == 0 { + randRead(h, backend, f, data, test) + } else { + randWrite(h, backend, f, data) + } + } + } + + // Start reading and writing. + var wg sync.WaitGroup + for i := 0; i < instances; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:]) + }(i) + } + wg.Wait() +} diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 95846e5f7..4d3271b37 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator { // Finish completes all checks and shuts down the server. func (h *Harness) Finish() { - h.clientSocket.Close() + h.clientSocket.Shutdown() h.wg.Wait() h.mockCtrl.Finish() } @@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) { }() // Create the client. - client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString()) + client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString()) if err != nil { serverSocket.Close() clientSocket.Close() diff --git a/pkg/p9/server.go b/pkg/p9/server.go index b294efbb0..69c886a5d 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -21,6 +21,9 @@ import ( "sync/atomic" "syscall" + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -45,7 +48,6 @@ type Server struct { } // NewServer returns a new server. -// func NewServer(attacher Attacher) *Server { return &Server{ attacher: attacher, @@ -85,6 +87,8 @@ type connState struct { // version 0 implies 9P2000.L. version uint32 + // -- below relates to the legacy handler -- + // recvOkay indicates that a receive may start. recvOkay chan bool @@ -93,6 +97,20 @@ type connState struct { // sendDone is signalled when a send is finished. sendDone chan error + + // -- below relates to the flipcall handler -- + + // channelMu protects below. + channelMu sync.Mutex + + // channelWg represents active workers. + channelWg sync.WaitGroup + + // channelAlloc allocates channel memory. + channelAlloc *flipcall.PacketWindowAllocator + + // channels are the set of initialized channels. + channels []*channel } // fidRef wraps a node and tracks references. @@ -386,6 +404,99 @@ func (cs *connState) WaitTag(t Tag) { <-ch } +// initializeChannels initializes all channels. +// +// This is a no-op if channels are already initialized. +func (cs *connState) initializeChannels() (err error) { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + + // Initialize our channel allocator. + if cs.channelAlloc == nil { + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return err + } + cs.channelAlloc = alloc + } + + // Create all the channels. + for len(cs.channels) < channelsPerClient { + res := &channel{ + done: make(chan struct{}), + } + + res.desc, err = cs.channelAlloc.Allocate(channelSize) + if err != nil { + return err + } + if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil { + return err + } + + socks, err := fdchannel.NewConnectedSockets() + if err != nil { + res.data.Destroy() // Cleanup. + return err + } + res.fds.Init(socks[0]) + res.client = fd.New(socks[1]) + + cs.channels = append(cs.channels, res) + + // Start servicing the channel. + // + // When we call stop, we will close all the channels and these + // routines should finish. We need the wait group to ensure + // that active handlers are actually finished before cleanup. + cs.channelWg.Add(1) + go func() { // S/R-SAFE: Server side. + defer cs.channelWg.Done() + res.service(cs) + }() + } + + return nil +} + +// lookupChannel looks up the channel with given id. +// +// The function returns nil if no such channel is available. +func (cs *connState) lookupChannel(id uint32) *channel { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + if id >= uint32(len(cs.channels)) { + return nil + } + return cs.channels[id] +} + +// handle handles a single message. +func (cs *connState) handle(m message) (r message) { + defer func() { + if r == nil { + // Don't allow a panic to propagate. + recover() + + // Include a useful log message. + log.Warningf("panic in handler: %s", debug.Stack()) + + // Wrap in an EFAULT error; we don't really have a + // better way to describe this kind of error. It will + // usually manifest as a result of the test framework. + r = newErr(syscall.EFAULT) + } + }() + if handler, ok := m.(handler); ok { + // Call the message handler. + r = handler.handle(cs) + } else { + // Produce an ENOSYS error. + r = newErr(syscall.ENOSYS) + } + return +} + // handleRequest handles a single request. // // The recvDone channel is signaled when recv is done (with a error if @@ -428,41 +539,20 @@ func (cs *connState) handleRequest() { } // Handle the message. - var r message // r is the response. - defer func() { - if r == nil { - // Don't allow a panic to propagate. - recover() + r := cs.handle(m) - // Include a useful log message. - log.Warningf("panic in handler: %s", debug.Stack()) + // Clear the tag before sending. That's because as soon as this hits + // the wire, the client can legally send the same tag. + cs.ClearTag(tag) - // Wrap in an EFAULT error; we don't really have a - // better way to describe this kind of error. It will - // usually manifest as a result of the test framework. - r = newErr(syscall.EFAULT) - } + // Send back the result. + cs.sendMu.Lock() + err = send(cs.conn, tag, r) + cs.sendMu.Unlock() + cs.sendDone <- err - // Clear the tag before sending. That's because as soon as this - // hits the wire, the client can legally send another message - // with the same tag. - cs.ClearTag(tag) - - // Send back the result. - cs.sendMu.Lock() - err = send(cs.conn, tag, r) - cs.sendMu.Unlock() - cs.sendDone <- err - }() - if handler, ok := m.(handler); ok { - // Call the message handler. - r = handler.handle(cs) - } else { - // Produce an ENOSYS error. - r = newErr(syscall.ENOSYS) - } + // Return the message to the cache. msgRegistry.put(m) - m = nil // 'm' should not be touched after this point. } func (cs *connState) handleRequests() { @@ -477,7 +567,27 @@ func (cs *connState) stop() { close(cs.recvDone) close(cs.sendDone) - for _, fidRef := range cs.fids { + // Free the channels. + cs.channelMu.Lock() + for _, ch := range cs.channels { + ch.Shutdown() + } + cs.channelWg.Wait() + for _, ch := range cs.channels { + ch.Close() + } + cs.channels = nil // Clear. + cs.channelMu.Unlock() + + // Free the channel memory. + if cs.channelAlloc != nil { + cs.channelAlloc.Destroy() + } + + // Close all remaining fids. + for fid, fidRef := range cs.fids { + delete(cs.fids, fid) + // Drop final reference in the FID table. Note this should // always close the file, since we've ensured that there are no // handlers running via the wait for Pending => 0 below. @@ -510,7 +620,7 @@ func (cs *connState) service() error { for i := 0; i < pending; i++ { <-cs.sendDone } - return err + return nil } // This handler is now pending. diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 5648df589..6e8b4bbcd 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -54,7 +54,10 @@ const ( headerLength uint32 = 7 // maximumLength is the largest possible message. - maximumLength uint32 = 4 * 1024 * 1024 + maximumLength uint32 = 1 << 20 + + // DefaultMessageSize is a sensible default. + DefaultMessageSize uint32 = 64 << 10 // initialBufferLength is the initial data buffer we allocate. initialBufferLength uint32 = 64 diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go new file mode 100644 index 000000000..7cdf4ecc3 --- /dev/null +++ b/pkg/p9/transport_flipcall.go @@ -0,0 +1,263 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "runtime" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +// channelsPerClient is the number of channels to create per client. +// +// While the client and server will generally agree on this number, in reality +// it's completely up to the server. We simply define a minimum of 2, and a +// maximum of 4, and select the number of available processes as a tie-breaker. +// Note that we don't want the number of channels to be too large, because each +// will account for channelSize memory used, which can be large. +var channelsPerClient = func() int { + n := runtime.NumCPU() + if n < 2 { + return 2 + } + if n > 4 { + return 4 + } + return n +}() + +// channelSize is the channel size to create. +// +// We simply ensure that this is larger than the largest possible message size, +// plus the flipcall packet header, plus the two bytes we write below. +const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength) + +// channel is a fast IPC channel. +// +// The same object is used by both the server and client implementations. In +// general, the client will use only the send and recv methods. +type channel struct { + desc flipcall.PacketWindowDescriptor + data flipcall.Endpoint + fds fdchannel.Endpoint + buf buffer + + // -- client only -- + connected bool + active bool + + // -- server only -- + client *fd.FD + done chan struct{} +} + +// reset resets the channel buffer. +func (ch *channel) reset(sz uint32) { + ch.buf.data = ch.data.Data()[:sz] +} + +// service services the channel. +func (ch *channel) service(cs *connState) error { + rsz, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rsz > 0 { + m, err := ch.recv(nil, rsz) + if err != nil { + return err + } + r := cs.handle(m) + msgRegistry.put(m) + rsz, err = ch.send(r) + if err != nil { + return err + } + } + return nil // Done. +} + +// Shutdown shuts down the channel. +// +// This must be called before Close. +func (ch *channel) Shutdown() { + ch.data.Shutdown() +} + +// Close closes the channel. +// +// This must only be called once, and cannot return an error. Note that +// synchronization for this method is provided at a high-level, depending on +// whether it is the client or server. This cannot be called while there are +// active callers in either service or sendRecv. +// +// Precondition: the channel should be shutdown. +func (ch *channel) Close() error { + // Close all backing transports. + ch.fds.Destroy() + ch.data.Destroy() + if ch.client != nil { + ch.client.Close() + } + return nil +} + +// send sends the given message. +// +// The return value is the size of the received response. Not that in the +// server case, this is the size of the next request. +func (ch *channel) send(m message) (uint32, error) { + if log.IsLogging(log.Debug) { + log.Debugf("send [channel @%p] %s", ch, m.String()) + } + + // Send any file payload. + sentFD := false + if filer, ok := m.(filer); ok { + if f := filer.FilePayload(); f != nil { + if err := ch.fds.SendFD(f.FD()); err != nil { + return 0, syscall.EIO // Map everything to EIO. + } + f.Close() // Per sendRecvLegacy. + sentFD = true // To mark below. + } + } + + // Encode the message. + // + // Note that IPC itself encodes the length of messages, so we don't + // need to encode a standard 9P header. We write only the message type. + ch.reset(0) + + ch.buf.WriteMsgType(m.Type()) + if sentFD { + ch.buf.Write8(1) // Incoming FD. + } else { + ch.buf.Write8(0) // No incoming FD. + } + m.Encode(&ch.buf) + ssz := uint32(len(ch.buf.data)) // Updated below. + + // Is there a payload? + if payloader, ok := m.(payloader); ok { + p := payloader.Payload() + copy(ch.data.Data()[ssz:], p) + ssz += uint32(len(p)) + } + + // Perform the one-shot communication. + n, err := ch.data.SendRecv(ssz) + if err != nil { + if n > 0 { + return n, nil + } + return 0, syscall.EIO // See above. + } + + return n, nil +} + +// recv decodes a message that exists on the channel. +// +// If the passed r is non-nil, then the type must match or an error will be +// generated. If the passed r is nil, then a new message will be created and +// returned. +func (ch *channel) recv(r message, rsz uint32) (message, error) { + // Decode the response from the inline buffer. + ch.reset(rsz) + t := ch.buf.ReadMsgType() + hasFD := ch.buf.Read8() != 0 + if t == MsgRlerror { + // Change the message type. We check for this special case + // after decoding below, and transform into an error. + r = &Rlerror{} + } else if r == nil { + nr, err := msgRegistry.get(0, t) + if err != nil { + return nil, err + } + r = nr // New message. + } else if t != r.Type() { + // Not an error and not the expected response; propagate. + return nil, &ErrBadResponse{Got: t, Want: r.Type()} + } + + // Is there a payload? Copy from the latter portion. + if payloader, ok := r.(payloader); ok { + fs := payloader.FixedSize() + p := payloader.Payload() + payloadData := ch.buf.data[fs:] + if len(p) < len(payloadData) { + p = make([]byte, len(payloadData)) + copy(p, payloadData) + payloader.SetPayload(p) + } else if n := copy(p, payloadData); n < len(p) { + payloader.SetPayload(p[:n]) + } + ch.buf.data = ch.buf.data[:fs] + } + + r.Decode(&ch.buf) + if ch.buf.isOverrun() { + // Nothing valid was available. + log.Debugf("recv [got %d bytes, needed more]", rsz) + return nil, ErrNoValidMessage + } + + // Read any FD result. + if hasFD { + if rfd, err := ch.fds.RecvFDNonblock(); err == nil { + f := fd.New(rfd) + if filer, ok := r.(filer); ok { + // Set the payload. + filer.SetFilePayload(f) + } else { + // Don't want the FD. + f.Close() + } + } else { + // The header bit was set but nothing came in. + log.Warningf("expected FD, got err: %v", err) + } + } + + // Log a message. + if log.IsLogging(log.Debug) { + log.Debugf("recv [channel @%p] %s", ch, r.String()) + } + + // Convert errors appropriately; see above. + if rlerr, ok := r.(*Rlerror); ok { + return nil, syscall.Errno(rlerr.Error) + } + + return r, nil +} + +// sendRecv sends the given message over the channel. +// +// This is used by the client. +func (ch *channel) sendRecv(c *Client, m, r message) error { + rsz, err := ch.send(m) + if err != nil { + return err + } + _, err = ch.recv(r, rsz) + return err +} diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index cdb3bc841..2f50ff3ea 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) { t.Fatalf("unable to create file: %v", err) } - if err := send(client, Tag(1), &Rlopen{File: f}); err != nil { + rlopen := &Rlopen{} + rlopen.SetFilePayload(f) + if err := send(client, Tag(1), rlopen); err != nil { t.Fatalf("send got err %v expected nil", err) } diff --git a/pkg/p9/version.go b/pkg/p9/version.go index c2a2885ae..f1ffdd23a 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 7 + highestSupportedVersion uint32 = 8 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -148,3 +148,10 @@ func VersionSupportsMultiUser(v uint32) bool { func versionSupportsTallocate(v uint32) bool { return v >= 7 } + +// versionSupportsFlipcall returns true if version v supports IPC channels from +// the flipcall package. Note that these must be negotiated, but this version +// string indicates that such a facility exists. +func versionSupportsFlipcall(v uint32) bool { + return v >= 8 +} diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 697e7a2f4..078f084b2 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 9c08452fc..827385139 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "weak_ref_list", diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index d1024e49d..af94e944d 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test") package(licenses = ["notice"]) diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD index f38fb39f3..22abdc69f 100644 --- a/pkg/secio/BUILD +++ b/pkg/secio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index 694486296..12d7c77d2 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:private"], diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD index 53989301f..2d6379c86 100644 --- a/pkg/sentry/BUILD +++ b/pkg/sentry/BUILD @@ -8,5 +8,7 @@ package_group( packages = [ "//pkg/sentry/...", "//runsc/...", + # Code generated by go_marshal relies on go_marshal libraries. + "//tools/go_marshal/...", ], ) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 7aace2d7b..c71cff9f3 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -1,4 +1,5 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -42,6 +43,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "registers_cc_proto", + visibility = ["//visibility:public"], + deps = [":registers_proto"], +) + go_proto_library( name = "registers_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto", diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index bf802d1b6..5522cecd0 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 7e8918722..0c86197f7 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "device", diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index d7259b47b..3119a61b6 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "fs", diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index fbca06761..3cb73bd78 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -1126,7 +1126,7 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error { // Remove removes the given file or symlink. The root dirent is used to // resolve name, and must not be nil. -func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error { +func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error { // Check the root. if root == nil { panic("Dirent.Remove: root must not be nil") @@ -1151,6 +1151,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error { // Remove cannot remove directories. if IsDir(child.Inode.StableAttr) { return syscall.EISDIR + } else if dirPath { + return syscall.ENOTDIR } // Remove cannot remove a mount point. diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go index 884e3ff06..47bc72a88 100644 --- a/pkg/sentry/fs/dirent_refs_test.go +++ b/pkg/sentry/fs/dirent_refs_test.go @@ -343,7 +343,7 @@ func TestRemoveExtraRefs(t *testing.T) { } d := f.Dirent - if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil { + if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil { t.Fatalf("root.Remove(root, %q) failed: %v", name, err) } diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index bf00b9c09..b9bd9ed17 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "fdpipe", diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index bb8117f89..c0a6e884b 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -515,6 +515,11 @@ type lockedReader struct { // File is the file to read from. File *File + + // Offset is the offset to start at. + // + // This applies only to Read, not ReadAt. + Offset int64 } // Read implements io.Reader.Read. @@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) { if r.Ctx.Interrupted() { return 0, syserror.ErrInterrupted } - n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset) + n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset) + r.Offset += n return int(n), err } @@ -544,11 +550,21 @@ type lockedWriter struct { // File is the file to write to. File *File + + // Offset is the offset to start at. + // + // This applies only to Write, not WriteAt. + Offset int64 } // Write implements io.Writer.Write. func (w *lockedWriter) Write(buf []byte) (int, error) { - return w.WriteAt(buf, w.File.offset) + if w.Ctx.Interrupted() { + return 0, syserror.ErrInterrupted + } + n, err := w.WriteAt(buf, w.Offset) + w.Offset += int64(n) + return int(n), err } // WriteAt implements io.Writer.WriteAt. @@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { // io.Copy, since our own Write interface does not have this same // contract. Enforce that here. for written < len(buf) { + if w.Ctx.Interrupted() { + return written, syserror.ErrInterrupted + } var n int64 n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written)) if n > 0 { diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index d86f5bf45..b88303f17 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -15,6 +15,8 @@ package fs import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -105,8 +107,11 @@ type FileOperations interface { // on the destination, following by a buffered copy with standard Read // and Write operations. // + // If dup is set, the data should be duplicated into the destination + // and retained. + // // The same preconditions as Read apply. - WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error) + WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error) // Write writes src to file at offset and returns the number of bytes // written which must be greater than or equal to 0. Like Read, file @@ -126,7 +131,7 @@ type FileOperations interface { // source. See WriteTo for details regarding how this is called. // // The same preconditions as Write apply; FileFlags.Write must be set. - ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error) + ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error) // Fsync writes buffered modifications of file and/or flushes in-flight // operations to backing storage based on syncType. The range to sync is diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 9820f0b13..225e40186 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "gvisor.dev/gvisor/pkg/refs" @@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme } // WriteTo implements FileOperations.WriteTo. -func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) { err = f.onTop(ctx, file, func(file *File, ops FileOperations) error { - n, err = ops.WriteTo(ctx, file, dst, opts) + n, err = ops.WriteTo(ctx, file, dst, count, dup) return err // Will overwrite itself. }) return @@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm } // ReadFrom implements FileOperations.ReadFrom. -func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) { // See above; f.upper must be non-nil. - return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts) + return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count) } // Fsync implements FileOperations.Fsync. diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 6499f87ac..b4ac83dc4 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "dirty_set_impl", diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 626b9126a..fc5b3b1a1 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -15,6 +15,8 @@ package fsutil import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu type FileNoSplice struct{} // WriteTo implements fs.FileOperations.WriteTo. -func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } // ReadFrom implements fs.FileOperations.ReadFrom. -func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 6b993928c..2b71ca0e1 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "gofer", diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index b1080fb1a..3e532332e 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "host", diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index c7f4e2d13..ba3e0233d 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "sync/atomic" @@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i } // WriteTo implements FileOperations.WriteTo. -func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } @@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error { } // ReadFrom implements FileOperations.ReadFrom. -func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 08d7c0c57..5a7a5b8cd 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "lock_range", diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index c7599d1f6..1c93e8886 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "proc", diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 20c3eefc8..76433c7d0 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "seqfile", diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 516efcc4c..d0f351e5a 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "ramfs", diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index eed1c2854..b03b7f836 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -18,7 +18,6 @@ import ( "io" "sync/atomic" - "gvisor.dev/gvisor/pkg/secio" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserror" ) @@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } // Check whether or not the objects being sliced are stream-oriented - // (i.e. pipes or sockets). If yes, we elide checks and offset locks. - srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr) - dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr) + // (i.e. pipes or sockets). For all stream-oriented files and files + // where a specific offiset is not request, we acquire the file mutex. + // This has two important side effects. First, it provides the standard + // protection against concurrent writes that would mutate the offset. + // Second, it prevents Splice deadlocks. Only internal anonymous files + // implement the ReadFrom and WriteTo methods directly, and since such + // anonymous files are referred to by a unique fs.File object, we know + // that the file mutex takes strict precedence over internal locks. + // Since we enforce lock ordering here, we can't deadlock by using + // using a file in two different splice operations simultaneously. + srcPipe := !IsRegular(src.Dirent.Inode.StableAttr) + dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr) + dstAppend := !dstPipe && dst.Flags().Append + srcLock := srcPipe || !opts.SrcOffset + dstLock := dstPipe || !opts.DstOffset || dstAppend - if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset { + switch { + case srcLock && dstLock: switch { case dst.UniqueID < src.UniqueID: // Acquire dst first. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() if !src.mu.Lock(ctx) { + dst.mu.Unlock() return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() case dst.UniqueID > src.UniqueID: // Acquire src first. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() if !dst.mu.Lock(ctx) { + src.mu.Unlock() return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() case dst.UniqueID == src.UniqueID: // Acquire only one lock; it's the same file. This is a // bit of a edge case, but presumably it's possible. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() + srcLock = false // Only need one unlock. } // Use both offsets (locked). opts.DstStart = dst.offset opts.SrcStart = src.offset - } else if !dstPipe && !opts.DstOffset { + case dstLock: // Acquire only dst. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() opts.DstStart = dst.offset // Safe: locked. - } else if !srcPipe && !opts.SrcOffset { + case srcLock: // Acquire only src. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() opts.SrcStart = src.offset // Safe: locked. } - // Check append-only mode and the limit. - if !dstPipe { + var err error + if dstAppend { unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append) defer unlock() - if dst.Flags().Append { - if opts.DstOffset { - // We need to acquire the lock. - if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted - } - defer dst.mu.Unlock() - } - // Figure out the appropriate offset to use. - if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil { - return 0, err - } - } + // Figure out the appropriate offset to use. + err = dst.offsetForAppend(ctx, &opts.DstStart) + } + if err == nil && !dstPipe { // Enforce file limits. limit, ok := dst.checkLimit(ctx, opts.DstStart) switch { case ok && limit == 0: - return 0, syserror.ErrExceedsFileSizeLimit + err = syserror.ErrExceedsFileSizeLimit case ok && limit < opts.Length: opts.Length = limit // Cap the write. } } + if err != nil { + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return 0, err + } - // Attempt to do a WriteTo; this is likely the most efficient. - // - // The underlying implementation may be able to donate buffers. - newOpts := SpliceOpts{ - Length: opts.Length, - SrcStart: opts.SrcStart, - SrcOffset: !srcPipe, - Dup: opts.Dup, - DstStart: opts.DstStart, - DstOffset: !dstPipe, + // Construct readers and writers for the splice. This is used to + // provide a safer locking path for the WriteTo/ReadFrom operations + // (since they will otherwise go through public interface methods which + // conflict with locking done above), and simplifies the fallback path. + w := &lockedWriter{ + Ctx: ctx, + File: dst, + Offset: opts.DstStart, } - n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts) - if n == 0 && err != nil { - // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also - // be more efficient than a copy if buffers are cached or readily - // available. (It's unlikely that they can actually be donate - n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts) + r := &lockedReader{ + Ctx: ctx, + File: src, + Offset: opts.SrcStart, } - if n == 0 && err != nil { - // If we've failed up to here, and at least one of the sources - // is a pipe or socket, then we can't properly support dup. - // Return an error indicating that this operation is not - // supported. - if (srcPipe || dstPipe) && newOpts.Dup { - return 0, syserror.EINVAL - } - // We failed to splice the files. But that's fine; we just fall - // back to a slow path in this case. This copies without doing - // any mode changes, so should still be more efficient. - var ( - r io.Reader - w io.Writer - ) - fw := &lockedWriter{ - Ctx: ctx, - File: dst, - } - if newOpts.DstOffset { - // Use the provided offset. - w = secio.NewOffsetWriter(fw, newOpts.DstStart) - } else { - // Writes will proceed with no offset. - w = fw - } - fr := &lockedReader{ - Ctx: ctx, - File: src, - } - if newOpts.SrcOffset { - // Limit to the given offset and length. - r = io.NewSectionReader(fr, opts.SrcStart, opts.Length) - } else { - // Limit just to the given length. - r = &io.LimitedReader{fr, opts.Length} - } + // Attempt to do a WriteTo; this is likely the most efficient. + n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) + if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup { + // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be + // more efficient than a copy if buffers are cached or readily + // available. (It's unlikely that they can actually be donated). + n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length) + } - // Copy between the two. - n, err = io.Copy(w, r) + // Support one last fallback option, but only if at least one of + // the source and destination are regular files. This is because + // if we block at some point, we could lose data. If the source is + // not a pipe then reading is not destructive; if the destination + // is a regular file, then it is guaranteed not to block writing. + if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup && (!dstPipe || !srcPipe) { + // Fallback to an in-kernel copy. + n, err = io.Copy(w, &io.LimitedReader{ + R: r, + N: opts.Length, + }) } // Update offsets, if required. @@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } } + // Drop locks. + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return n, err } diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 8f7eb5757..11b680929 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tmpfs", diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5e9327aec..25811f668 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tty", @@ -23,6 +25,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 1d128532b..2f639c823 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -129,6 +129,9 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode { // Release implements fs.InodeOperations.Release. func (d *dirInodeOperations) Release(ctx context.Context) { + d.mu.Lock() + defer d.mu.Unlock() + d.master.DecRef() if len(d.slaves) != 0 { panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d)) diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 92ec1ca18..19b7557d5 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -172,6 +172,19 @@ func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io userme return 0, mf.t.ld.windowSize(ctx, io, args) case linux.TIOCSWINSZ: return 0, mf.t.ld.setWindowSize(ctx, io, args) + case linux.TIOCSCTTY: + // Make the given terminal the controlling terminal of the + // calling process. + return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY @@ -185,8 +198,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { linux.TCSETS, linux.TCSETSW, linux.TCSETSF, - linux.TIOCGPGRP, - linux.TIOCSPGRP, linux.TIOCGWINSZ, linux.TIOCSWINSZ, linux.TIOCSETD, @@ -200,8 +211,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { linux.TIOCEXCL, linux.TIOCNXCL, linux.TIOCGEXCL, - linux.TIOCNOTTY, - linux.TIOCSCTTY, linux.TIOCGSID, linux.TIOCGETD, linux.TIOCVHANGUP, diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index e30266404..944c4ada1 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -152,9 +152,16 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - // TODO(b/129283598): Implement once we have support for job - // control. - return 0, nil + return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index b7cecb2ed..ff8138820 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -17,7 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) // Terminal is a pseudoterminal. @@ -26,23 +29,100 @@ import ( type Terminal struct { refs.AtomicRefCount - // n is the terminal index. + // n is the terminal index. It is immutable. n uint32 - // d is the containing directory. + // d is the containing directory. It is immutable. d *dirInodeOperations - // ld is the line discipline of the terminal. + // ld is the line discipline of the terminal. It is immutable. ld *lineDiscipline + + // masterKTTY contains the controlling process of the master end of + // this terminal. This field is immutable. + masterKTTY *kernel.TTY + + // slaveKTTY contains the controlling process of the slave end of this + // terminal. This field is immutable. + slaveKTTY *kernel.TTY } func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal { termios := linux.DefaultSlaveTermios t := Terminal{ - d: d, - n: n, - ld: newLineDiscipline(termios), + d: d, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{}, + slaveKTTY: &kernel.TTY{}, } t.EnableLeakCheck("tty.Terminal") return &t } + +// setControllingTTY makes tm the controlling terminal of the calling thread +// group. +func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("setControllingTTY must be called from a task context") + } + + return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int()) +} + +// releaseControllingTTY removes tm as the controlling terminal of the calling +// thread group. +func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("releaseControllingTTY must be called from a task context") + } + + return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster)) +} + +// foregroundProcessGroup gets the process group ID of tm's foreground process. +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("foregroundProcessGroup must be called from a task context") + } + + ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster)) + if err != nil { + return 0, err + } + + // Write it out to *arg. + _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err +} + +// foregroundProcessGroup sets tm's foreground process. +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("setForegroundProcessGroup must be called from a task context") + } + + // Read in the process group ID. + var pgid int32 + if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + + ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid)) + return uintptr(ret), err +} + +func (tm *Terminal) tty(isMaster bool) *kernel.TTY { + if isMaster { + return tm.masterKTTY + } + return tm.slaveKTTY +} diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 9e8ebb907..b0c286b7a 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") go_template_instance( diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 9fddb4c4c..bfc46dfa6 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index b51f3e18d..0b471d121 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba } if !cb.Handle(vfs.Dirent{ - Name: child.diskDirent.FileName(), - Type: fs.ToDirentType(childType), - Ino: uint64(child.diskDirent.Inode()), - Off: fd.off, + Name: child.diskDirent.FileName(), + Type: fs.ToDirentType(childType), + Ino: uint64(child.diskDirent.Inode()), + NextOff: fd.off + 1, }) { dir.childList.InsertBefore(child, fd.iter) return nil diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 907d35b7e..2d50e30aa 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "disklayout", diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 63cf7aeaf..1aa2bd6a4 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) { // Ignore the inode number and offset of dirents because those are likely to // change as the underlying image changes. cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool { - return p.String() == "Ino" || p.String() == "Off" + return p.String() == "Ino" || p.String() == "NextOff" }, cmp.Ignore()) if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" { t.Errorf("dirents mismatch (-want +got):\n%s", diff) diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD index d2450e810..7e364c5fd 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/memfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go index c52dc781c..c620227c9 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/memfs/directory.go @@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if fd.off == 0 { if !cb.Handle(vfs.Dirent{ - Name: ".", - Type: linux.DT_DIR, - Ino: vfsd.Impl().(*dentry).inode.ino, - Off: 0, + Name: ".", + Type: linux.DT_DIR, + Ino: vfsd.Impl().(*dentry).inode.ino, + NextOff: 1, }) { return nil } @@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if fd.off == 1 { parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode if !cb.Handle(vfs.Dirent{ - Name: "..", - Type: parentInode.direntType(), - Ino: parentInode.ino, - Off: 1, + Name: "..", + Type: parentInode.direntType(), + Ino: parentInode.ino, + NextOff: 2, }) { return nil } @@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba // Skip other directoryFD iterators. if child.inode != nil { if !cb.Handle(vfs.Dirent{ - Name: child.vfsd.Name(), - Type: child.inode.direntType(), - Ino: child.inode.ino, - Off: fd.off, + Name: child.vfsd.Name(), + Type: child.inode.direntType(), + Ino: child.inode.ino, + NextOff: fd.off + 1, }) { dir.childList.InsertBefore(child, fd.iter) return nil diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 3d8a4deaf..ade6ac946 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD index f989f2f8b..359468ccc 100644 --- a/pkg/sentry/hostcpu/BUILD +++ b/pkg/sentry/hostcpu/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,6 +7,7 @@ go_library( name = "hostcpu", srcs = [ "getcpu_amd64.s", + "getcpu_arm64.s", "hostcpu.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu", diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s new file mode 100644 index 000000000..caf9abb89 --- /dev/null +++ b/pkg/sentry/hostcpu/getcpu_arm64.s @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall for +// the lack of an optimazed way of getting the current CPU number on arm64. + +// func GetCPU() (cpu uint32) +TEXT ·GetCPU(SB), NOSPLIT, $0-4 + MOVW ZR, cpu+0(FP) + MOVD $cpu+0(FP), R0 + MOVD $0x0, R1 // unused + MOVD $0x0, R2 // unused + MOVD $0xA8, R8 // SYS_GETCPU + SVC + RET diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e61d39c82..aba2414d4 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,9 +1,11 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "pending_signals_list", @@ -83,6 +85,12 @@ proto_library( deps = ["//pkg/sentry/arch:registers_proto"], ) +cc_proto_library( + name = "uncaught_signal_cc_proto", + visibility = ["//visibility:public"], + deps = [":uncaught_signal_proto"], +) + go_proto_library( name = "uncaught_signal_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto", @@ -144,6 +152,7 @@ go_library( "threads.go", "timekeeper.go", "timekeeper_state.go", + "tty.go", "uts_namespace.go", "vdso.go", "version.go", diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index f46c43128..65427b112 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "epoll_list", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index 1c5f979d4..983ca67ed 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "eventfd", diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 6a31dc044..41f44999c 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "atomicptr_bucket", diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index ebcfaa619..d7a7d1169 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -24,6 +25,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "memory_events_cc_proto", + visibility = ["//visibility:public"], + deps = [":memory_events_proto"], +) + go_proto_library( name = "memory_events_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto", diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 4d15cca85..2ce8952e2 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "buffer_list", diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 69ef2a720..95bee2d37 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" @@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } +// WriteFromReader writes to the buffer from an io.Reader. +func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) { + dst := b.data[b.write:] + if count < int64(len(dst)) { + dst = b.data[b.write:][:count] + } + n, err := r.Read(dst) + b.write += n + return int64(n), err +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write])) @@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return n, err } +// ReadToWriter reads from the buffer into an io.Writer. +func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) { + src := b.data[b.read:b.write] + if count < int64(len(src)) { + src = b.data[b.read:][:count] + } + n, err := w.Write(src) + if !dup { + b.read += n + } + return int64(n), err +} + // bufferPool is a pool for buffers. var bufferPool = sync.Pool{ New: func() interface{} { diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 247e2928e..93b50669f 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F } } +type readOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit limits subsequence reads. + limit func(int64) + + // read performs the actual read operation. + read func(*buffer) (int64, error) +} + // read reads data from the pipe into dst and returns the number of bytes // read, or returns ErrWouldBlock if the pipe is empty. // // Precondition: this pipe must have readers. -func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) { +func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { // Don't block for a zero-length read even if the pipe is empty. - if dst.NumBytes() == 0 { + if ops.left() == 0 { return 0, nil } @@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Limit how much we consume. - if dst.NumBytes() > p.size { - dst = dst.TakeFirst64(p.size) + if ops.left() > p.size { + ops.limit(p.size) } done := int64(0) - for dst.NumBytes() > 0 { + for ops.left() > 0 { // Pop the first buffer. first := p.data.Front() if first == nil { @@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Copy user data. - n, err := dst.CopyOutFrom(ctx, first) + n, err := ops.read(first) done += int64(n) p.size -= n - dst = dst.DropFirst64(n) // Empty buffer? if first.Empty() { @@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) return done, nil } +// dup duplicates all data from this pipe into the given writer. +// +// There is no blocking behavior implemented here. The writer may propagate +// some blocking error. All the writes must be complete writes. +func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Is the pipe empty? + if p.size == 0 { + if !p.HasWriters() { + // See above. + return 0, nil + } + return 0, syserror.ErrWouldBlock + } + + // Limit how much we consume. + if ops.left() > p.size { + ops.limit(p.size) + } + + done := int64(0) + for buf := p.data.Front(); buf != nil; buf = buf.Next() { + n, err := ops.read(buf) + done += n + if err != nil { + return done, err + } + } + + return done, nil +} + +type writeOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit should limit subsequent writes. + limit func(int64) + + // write should write to the provided buffer. + write func(*buffer) (int64, error) +} + // write writes data from sv into the pipe and returns the number of bytes // written. If no bytes are written because the pipe is full (or has less than // atomicIOBytes free capacity), write returns ErrWouldBlock. // // Precondition: this pipe must have writers. -func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) { +func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() @@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. - wanted := src.NumBytes() + wanted := ops.left() if avail := p.max - p.size; wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } - // Limit to the available capacity. - src = src.TakeFirst64(avail) + ops.limit(avail) } done := int64(0) - for src.NumBytes() > 0 { + for ops.left() > 0 { // Need a new buffer? last := p.data.Back() if last == nil || last.Full() { @@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) } // Copy user data. - n, err := src.CopyInTo(ctx, last) + n, err := ops.write(last) done += int64(n) p.size += n - src = src.DropFirst64(n) // Handle errors. if err != nil { diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index f69dbf27b..7c307f013 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "math" "syscall" @@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() { // Read implements fs.FileOperations.Read. func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.read(ctx, dst) + n, err := rw.Pipe.read(ctx, readOps{ + left: func() int64 { + return dst.NumBytes() + }, + limit: func(l int64) { + dst = dst.TakeFirst64(l) + }, + read: func(buf *buffer) (int64, error) { + n, err := dst.CopyOutFrom(ctx, buf) + dst = dst.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventOut) + } + return n, err +} + +// WriteTo implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { + ops := readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(buf *buffer) (int64, error) { + n, err := buf.ReadToWriter(w, count, dup) + count -= n + return n, err + }, + } + if dup { + // There is no notification for dup operations. + return rw.Pipe.dup(ctx, ops) + } + n, err := rw.Pipe.read(ctx, ops) if n > 0 { rw.Pipe.Notify(waiter.EventOut) } @@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, src) + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return src.NumBytes() + }, + limit: func(l int64) { + src = src.TakeFirst64(l) + }, + write: func(buf *buffer) (int64, error) { + n, err := src.CopyInTo(ctx, buf) + src = src.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventIn) + } + return n, err +} + +// ReadFrom implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(buf *buffer) (int64, error) { + n, err := buf.WriteFromReader(r, count) + count -= n + return n, err + }, + }) if n > 0 { rw.Pipe.Notify(waiter.EventIn) } diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD index 1725b8562..98ea7a0d8 100644 --- a/pkg/sentry/kernel/sched/BUILD +++ b/pkg/sentry/kernel/sched/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 36edf10f3..80e5e5da3 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "waiter_list", diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 81fcd8258..047b5214d 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -47,6 +47,11 @@ type Session struct { // The id is immutable. id SessionID + // foreground is the foreground process group. + // + // This is protected by TaskSet.mu. + foreground *ProcessGroup + // ProcessGroups is a list of process groups in this Session. This is // protected by TaskSet.mu. processGroups processGroupList @@ -260,12 +265,14 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error { func (tg *ThreadGroup) CreateSession() error { tg.pidns.owner.mu.Lock() defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() return tg.createSession() } // createSession creates a new session for a threadgroup. // -// Precondition: callers must hold TaskSet.mu for writing. +// Precondition: callers must hold TaskSet.mu and the signal mutex for writing. func (tg *ThreadGroup) createSession() error { // Get the ID for this thread in the current namespace. id := tg.pidns.tgids[tg] @@ -321,8 +328,14 @@ func (tg *ThreadGroup) createSession() error { childTG.processGroup.incRefWithParent(pg) childTG.processGroup.decRefWithParent(oldParentPG) }) - tg.processGroup.decRefWithParent(oldParentPG) + // If tg.processGroup is an orphan, decRefWithParent will lock + // the signal mutex of each thread group in tg.processGroup. + // However, tg's signal mutex may already be locked at this + // point. We change tg's process group before calling + // decRefWithParent to avoid locking tg's signal mutex twice. + oldPG := tg.processGroup tg.processGroup = pg + oldPG.decRefWithParent(oldParentPG) } else { // The current process group may be nil only in the case of an // unparented thread group (i.e. the init process). This would @@ -346,6 +359,9 @@ func (tg *ThreadGroup) createSession() error { ns.processGroups[ProcessGroupID(local)] = pg } + // Disconnect from the controlling terminal. + tg.tty = nil + return nil } diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD new file mode 100644 index 000000000..50b69d154 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -0,0 +1,22 @@ +package(licenses = ["notice"]) + +load("//tools/go_stateify:defs.bzl", "go_library") + +go_library( + name = "signalfd", + srcs = ["signalfd.go"], + importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/sentry/context", + "//pkg/sentry/fs", + "//pkg/sentry/fs/anon", + "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", + "//pkg/sentry/usermem", + "//pkg/syserror", + "//pkg/waiter", + ], +) diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go new file mode 100644 index 000000000..06fd5ec88 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -0,0 +1,137 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package signalfd provides an implementation of signal file descriptors. +package signalfd + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/anon" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SignalOperations represent a file with signalfd semantics. +// +// +stateify savable +type SignalOperations struct { + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FilePipeSeek `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoFsync `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + // target is the original task target. + // + // The semantics here are a bit broken. Linux will always use current + // for all reads, regardless of where the signalfd originated. We can't + // do exactly that because we need to plumb the context through + // EventRegister in order to support proper blocking behavior. This + // will undoubtedly become very complicated quickly. + target *kernel.Task + + // mu protects below. + mu sync.Mutex `state:"nosave"` + + // mask is the signal mask. Protected by mu. + mask linux.SignalSet +} + +// New creates a new signalfd object with the supplied mask. +func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // No task context? Not valid. + return nil, syserror.EINVAL + } + // name matches fs/signalfd.c:signalfd4. + dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]") + return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{ + target: t, + mask: mask, + }), nil +} + +// Release implements fs.FileOperations.Release. +func (s *SignalOperations) Release() {} + +// Mask returns the signal mask. +func (s *SignalOperations) Mask() linux.SignalSet { + s.mu.Lock() + mask := s.mask + s.mu.Unlock() + return mask +} + +// SetMask sets the signal mask. +func (s *SignalOperations) SetMask(mask linux.SignalSet) { + s.mu.Lock() + s.mask = mask + s.mu.Unlock() +} + +// Read implements fs.FileOperations.Read. +func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { + // Attempt to dequeue relevant signals. + info, err := s.target.Sigtimedwait(s.Mask(), 0) + if err != nil { + // There must be no signal available. + return 0, syserror.ErrWouldBlock + } + + // Copy out the signal info using the specified format. + var buf [128]byte + binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + Signo: uint32(info.Signo), + Errno: info.Errno, + Code: info.Code, + PID: uint32(info.Pid()), + UID: uint32(info.Uid()), + Status: info.Status(), + Overrun: uint32(info.Overrun()), + Addr: info.Addr(), + }) + n, err := dst.CopyOut(ctx, buf[:]) + return int64(n), err +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return mask & waiter.EventIn +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) { + // Register for the signal set; ignore the passed events. + s.target.SignalRegister(entry, waiter.EventMask(s.Mask())) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SignalOperations) EventUnregister(entry *waiter.Entry) { + // Unregister the original entry. + s.target.SignalUnregister(entry) +} diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index e91f82bb3..c82ef5486 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/third_party/gvsync" ) @@ -133,6 +134,13 @@ type Task struct { // signalStack is exclusive to the task goroutine. signalStack arch.SignalStack + // signalQueue is a set of registered waiters for signal-related events. + // + // signalQueue is protected by the signalMutex. Note that the task does + // not implement all queue methods, specifically the readiness checks. + // The task only broadcast a notification on signal delivery. + signalQueue waiter.Queue `state:"zerovalue"` + // If groupStopPending is true, the task should participate in a group // stop in the interrupt path. // diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 266959a07..39cd1340d 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -28,6 +28,7 @@ import ( ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" ) // SignalAction is an internal signal action. @@ -497,6 +498,9 @@ func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) { // // Preconditions: The signal mutex must be locked. func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool { + // Notify that the signal is queued. + t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig))) + // - Do not choose tasks that are blocking the signal. if linux.SignalSetOf(sig)&t.signalMask != 0 { return false @@ -1108,3 +1112,17 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { t.tg.signalHandlers.mu.Unlock() return t.deliverSignal(info, act) } + +// SignalRegister registers a waiter for pending signals. +func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) { + t.tg.signalHandlers.mu.Lock() + t.signalQueue.EventRegister(e, mask) + t.tg.signalHandlers.mu.Unlock() +} + +// SignalUnregister unregisters a waiter for pending signals. +func (t *Task) SignalUnregister(e *waiter.Entry) { + t.tg.signalHandlers.mu.Lock() + t.signalQueue.EventUnregister(e) + t.tg.signalHandlers.mu.Unlock() +} diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index d60cd62c7..ae6fc4025 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -172,9 +172,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { if parentPG := tg.parentPG(); parentPG == nil { tg.createSession() } else { - // Inherit the process group. + // Inherit the process group and terminal. parentPG.incRefWithParent(parentPG) tg.processGroup = parentPG + tg.tty = t.parent.tg.tty } } tg.tasks.PushBack(t) diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 2a97e3e8e..0eef24bfb 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -19,10 +19,13 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/syserror" ) // A ThreadGroup is a logical grouping of tasks that has widespread @@ -245,6 +248,12 @@ type ThreadGroup struct { // // mounts is immutable. mounts *fs.MountNamespace + + // tty is the thread group's controlling terminal. If nil, there is no + // controlling terminal. + // + // tty is protected by the signal mutex. + tty *TTY } // newThreadGroup returns a new, empty thread group in PID namespace ns. The @@ -324,6 +333,176 @@ func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) { } } +// SetControllingTTY sets tty as the controlling terminal of tg. +func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error { + tty.mu.Lock() + defer tty.mu.Unlock() + + // We might be asked to set the controlling terminal of multiple + // processes, so we lock both the TaskSet and SignalHandlers. + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // "The calling process must be a session leader and not have a + // controlling terminal already." - tty_ioctl(4) + if tg.processGroup.session.leader != tg || tg.tty != nil { + return syserror.EINVAL + } + + // "If this terminal is already the controlling terminal of a different + // session group, then the ioctl fails with EPERM, unless the caller + // has the CAP_SYS_ADMIN capability and arg equals 1, in which case the + // terminal is stolen, and all processes that had it as controlling + // terminal lose it." - tty_ioctl(4) + if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session { + if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 { + return syserror.EPERM + } + // Steal the TTY away. Unlike TIOCNOTTY, don't send signals. + for othertg := range tg.pidns.owner.Root.tgids { + // This won't deadlock by locking tg.signalHandlers + // because at this point: + // - We only lock signalHandlers if it's in the same + // session as the tty's controlling thread group. + // - We know that the calling thread group is not in + // the same session as the tty's controlling thread + // group. + if othertg.processGroup.session == tty.tg.processGroup.session { + othertg.signalHandlers.mu.Lock() + othertg.tty = nil + othertg.signalHandlers.mu.Unlock() + } + } + } + + // Set the controlling terminal and foreground process group. + tg.tty = tty + tg.processGroup.session.foreground = tg.processGroup + // Set this as the controlling process of the terminal. + tty.tg = tg + + return nil +} + +// ReleaseControllingTTY gives up tty as the controlling tty of tg. +func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error { + tty.mu.Lock() + defer tty.mu.Unlock() + + // We might be asked to set the controlling terminal of multiple + // processes, so we lock both the TaskSet and SignalHandlers. + tg.pidns.owner.mu.RLock() + defer tg.pidns.owner.mu.RUnlock() + + // Just below, we may re-lock signalHandlers in order to send signals. + // Thus we can't defer Unlock here. + tg.signalHandlers.mu.Lock() + + if tg.tty == nil || tg.tty != tty { + tg.signalHandlers.mu.Unlock() + return syserror.ENOTTY + } + + // "If the process was session leader, then send SIGHUP and SIGCONT to + // the foreground process group and all processes in the current + // session lose their controlling terminal." - tty_ioctl(4) + // Remove tty as the controlling tty for each process in the session, + // then send them SIGHUP and SIGCONT. + + // If we're not the session leader, we don't have to do much. + if tty.tg != tg { + tg.tty = nil + tg.signalHandlers.mu.Unlock() + return nil + } + + tg.signalHandlers.mu.Unlock() + + // We're the session leader. SIGHUP and SIGCONT the foreground process + // group and remove all controlling terminals in the session. + var lastErr error + for othertg := range tg.pidns.owner.Root.tgids { + if othertg.processGroup.session == tg.processGroup.session { + othertg.signalHandlers.mu.Lock() + othertg.tty = nil + if othertg.processGroup == tg.processGroup.session.foreground { + if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil { + lastErr = err + } + if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil { + lastErr = err + } + } + othertg.signalHandlers.mu.Unlock() + } + } + + return lastErr +} + +// ForegroundProcessGroup returns the process group ID of the foreground +// process group. +func (tg *ThreadGroup) ForegroundProcessGroup(tty *TTY) (int32, error) { + tty.mu.Lock() + defer tty.mu.Unlock() + + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // "When fd does not refer to the controlling terminal of the calling + // process, -1 is returned" - tcgetpgrp(3) + if tg.tty != tty { + return -1, syserror.ENOTTY + } + + return int32(tg.processGroup.session.foreground.id), nil +} + +// SetForegroundProcessGroup sets the foreground process group of tty to pgid. +func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) (int32, error) { + tty.mu.Lock() + defer tty.mu.Unlock() + + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // TODO(b/129283598): "If tcsetpgrp() is called by a member of a + // background process group in its session, and the calling process is + // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all + // members of this background process group." + + // tty must be the controlling terminal. + if tg.tty != tty { + return -1, syserror.ENOTTY + } + + // pgid must be positive. + if pgid < 0 { + return -1, syserror.EINVAL + } + + // pg must not be empty. Empty process groups are removed from their + // pid namespaces. + pg, ok := tg.pidns.processGroups[pgid] + if !ok { + return -1, syserror.ESRCH + } + + // pg must be part of this process's session. + if tg.processGroup.session != pg.session { + return -1, syserror.EPERM + } + + tg.processGroup.session.foreground.id = pgid + return 0, nil +} + // itimerRealListener implements ktime.Listener for ITIMER_REAL expirations. // // +stateify savable diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go new file mode 100644 index 000000000..34f84487a --- /dev/null +++ b/pkg/sentry/kernel/tty.go @@ -0,0 +1,28 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import "sync" + +// TTY defines the relationship between a thread group and its controlling +// terminal. +// +// +stateify savable +type TTY struct { + mu sync.Mutex `state:"nosave"` + + // tg is protected by mu. + tg *ThreadGroup +} diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 40025d62d..59649c770 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "limits", diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 29c14ec56..9687e7e76 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "mappable_range", diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 072745a08..b35c8c673 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "file_refcount_set", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 858f895f2..3fd904c67 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "evictable_range", diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index eeb634644..b6d008dbe 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index fe979dccf..31fa48ec5 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 47957bb3b..72c7ec564 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -154,3 +154,19 @@ func (t *thread) clone() (*thread, error) { cpu: ^uint32(0), }, nil } + +// getEventMessage retrieves a message about the ptrace event that just happened. +func (t *thread) getEventMessage() (uintptr, error) { + var msg uintptr + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_GETEVENTMSG, + uintptr(t.tid), + 0, + uintptr(unsafe.Pointer(&msg)), + 0, 0) + if errno != 0 { + return msg, errno + } + return msg, nil +} diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 6bf7cd097..4f8f9c5d9 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -355,7 +355,8 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal { } if stopSig == syscall.SIGTRAP { if status.TrapCause() == syscall.PTRACE_EVENT_EXIT { - t.dumpAndPanic("wait failed: the process exited") + msg, err := t.getEventMessage() + t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) } // Re-encode the trap cause the way it's expected. return stopSig | syscall.Signal(status.TrapCause()<<8) @@ -426,6 +427,9 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { break } else { // Some other signal caused a thread stop; ignore. + if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD { + log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig) + } continue } } diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 3b95af617..ea090b686 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD index 924d8a6d6..6769cd0a5 100644 --- a/pkg/sentry/platform/safecopy/BUILD +++ b/pkg/sentry/platform/safecopy/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD index fd6dc8e6e..884020f7b 100644 --- a/pkg/sentry/safemem/BUILD +++ b/pkg/sentry/safemem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 0e37ce61b..3e66f9cbb 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -26,6 +26,7 @@ package epsocket import ( "bytes" + "io" "math" "reflect" "sync" @@ -208,6 +209,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and + // transport.Endpoint.SetSockOptInt. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error @@ -227,7 +232,6 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout *waiter.Queue @@ -412,17 +416,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return int64(n), nil } -// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand -// based on the requested size. +// WriteTo implements fs.FileOperations.WriteTo. +func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { + s.readMu.Lock() + + // Copy as much data as possible. + done := int64(0) + for count > 0 { + // This may return a blocking error. + if err := s.fetchReadView(); err != nil { + s.readMu.Unlock() + return done, err.ToError() + } + + // Write to the underlying file. + n, err := dst.Write(s.readView) + done += int64(n) + count -= int64(n) + if dup { + // That's all we support for dup. This is generally + // supported by any Linux system calls, but the + // expectation is that now a caller will call read to + // actually remove these bytes from the socket. + break + } + + // Drop that part of the view. + s.readView.TrimFront(n) + if err != nil { + s.readMu.Unlock() + return done, err + } + } + + s.readMu.Unlock() + return done, nil +} + +// ioSequencePayload implements tcpip.Payload. +// +// t copies user memory bytes on demand based on the requested size. type ioSequencePayload struct { ctx context.Context src usermem.IOSequence } -// Get implements tcpip.Payload. -func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { - if size > i.Size() { - size = i.Size() +// FullPayload implements tcpip.Payloader.FullPayload +func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { + return i.Payload(int(i.src.NumBytes())) +} + +// Payload implements tcpip.Payloader.Payload. +func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { + if max := int(i.src.NumBytes()); size > max { + size = max } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { @@ -431,11 +478,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { return v, nil } -// Size implements tcpip.Payload. -func (i *ioSequencePayload) Size() int { - return int(i.src.NumBytes()) -} - // DropFirst drops the first n bytes from underlying src. func (i *ioSequencePayload) DropFirst(n int) { i.src = i.src.DropFirst(int(n)) @@ -469,6 +511,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), nil } +// readerPayload implements tcpip.Payloader. +// +// It allocates a view and reads from a reader on-demand, based on available +// capacity in the endpoint. +type readerPayload struct { + ctx context.Context + r io.Reader + count int64 + err error +} + +// FullPayload implements tcpip.Payloader.FullPayload. +func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { + return r.Payload(int(r.count)) +} + +// Payload implements tcpip.Payloader.Payload. +func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { + if size > int(r.count) { + size = int(r.count) + } + v := buffer.NewView(size) + n, err := r.r.Read(v) + if n > 0 { + // We ignore the error here. It may re-occur on subsequent + // reads, but for now we can enqueue some amount of data. + r.count -= int64(n) + return v[:n], nil + } + if err == syserror.ErrWouldBlock { + return nil, tcpip.ErrWouldBlock + } else if err != nil { + r.err = err // Save for propation. + return nil, tcpip.ErrBadAddress + } + + // There is no data and no error. Return an error, which will propagate + // r.err, which will be nil. This is the desired result: (0, nil). + return nil, tcpip.ErrBadAddress +} + +// ReadFrom implements fs.FileOperations.ReadFrom. +func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + f := &readerPayload{ctx: ctx, r: r, count: count} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ + Atomic: true, // See above. + }) + } + if err == tcpip.ErrWouldBlock { + return n, syserror.ErrWouldBlock + } else if err != nil { + return int64(n), f.err // Propagate error. + } + + return int64(n), nil +} + // Readiness returns a mask of ready events for socket s. func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) @@ -777,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -793,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1165,7 +1279,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1173,7 +1287,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -2060,7 +2174,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= int64(v.Size()) || dontWait) { + if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. return int(n), nil } @@ -2085,7 +2199,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.TranslateNetstackError(err) } - if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } @@ -2207,9 +2321,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCOUTQ: - var v tcpip.SendQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 9e2e12799..445080aa4 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "port", diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 5061dcbde..3a6baa308 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -49,6 +50,14 @@ proto_library( ], ) +cc_proto_library( + name = "syscall_rpc_cc_proto", + visibility = [ + "//visibility:public", + ], + deps = [":syscall_rpc_proto"], +) + go_proto_library( name = "syscall_rpc_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2b0ad6395..1867b3a5c 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has + // the int type. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error @@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrQueueSizeNotSupported } return v, nil - default: - return -1, tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - case *tcpip.SendQueueSizeOption: + case tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v := e.connected.SendQueuedSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported - } - *o = qs - return nil - - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - return nil + return int(v), nil - case *tcpip.SendBufferSizeOption: + case tcpip.SendBufferSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + v := e.connected.SendMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs - return nil + return int(v), nil - case *tcpip.ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + v := e.receiver.RecvMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return int(v), nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) } - *o = qs return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 445d25010..7d7b42eba 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -44,6 +45,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "strace_cc_proto", + visibility = ["//visibility:public"], + deps = [":strace_proto"], +) + go_proto_library( name = "strace_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto", diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go index 3650fd6e1..5d57b75af 100644 --- a/pkg/sentry/strace/linux64.go +++ b/pkg/sentry/strace/linux64.go @@ -335,4 +335,5 @@ var linuxAMD64 = SyscallMap{ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex), 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex), 317: makeSyscallInfo("seccomp", Hex, Hex, Hex), + 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex), } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 33a40b9c6..e76ee27d2 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -74,6 +74,7 @@ go_library( "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/sched", "//pkg/sentry/kernel/shm", + "//pkg/sentry/kernel/signalfd", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/memmap", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ed996ba51..18d24ab61 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -320,21 +320,21 @@ var AMD64 = &kernel.SyscallTable{ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) 280: syscalls.Supported("utimensat", Utimensat), 281: syscalls.Supported("epoll_pwait", EpollPwait), - 282: syscalls.ErrorWithEvent("signalfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), 283: syscalls.Supported("timerfd_create", TimerfdCreate), 284: syscalls.Supported("eventfd", Eventfd), 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), 286: syscalls.Supported("timerfd_settime", TimerfdSettime), 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), 288: syscalls.Supported("accept4", Accept4), - 289: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), 290: syscalls.Supported("eventfd2", Eventfd2), 291: syscalls.Supported("epoll_create1", EpollCreate1), 292: syscalls.Supported("dup3", Dup3), diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 2e00a91ce..b9a8e3e21 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -1423,9 +1423,6 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { if err != nil { return err } - if dirPath { - return syserror.ENOENT - } return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error { if !fs.IsDir(d.Inode.StableAttr) { @@ -1436,7 +1433,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { return err } - return d.Remove(t, root, name) + return d.Remove(t, root, name, dirPath) }) } diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 0104a94c0..fb6efd5d8 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -20,7 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd" + "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -506,3 +509,77 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?") return 0, nil, syserror.EINTR } + +// sharedSignalfd is shared between the two calls. +func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) { + // Copy in the signal mask. + mask, err := copyInSigSet(t, sigset, sigsetsize) + if err != nil { + return 0, nil, err + } + + // Always check for valid flags, even if not creating. + if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 { + return 0, nil, syserror.EINVAL + } + + // Is this a change to an existing signalfd? + // + // The spec indicates that this should adjust the mask. + if fd != -1 { + file := t.GetFile(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Is this a signalfd? + if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok { + s.SetMask(mask) + return 0, nil, nil + } + + // Not a signalfd. + return 0, nil, syserror.EINVAL + } + + // Create a new file. + file, err := signalfd.New(t, mask) + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + // Set appropriate flags. + file.SetFlags(fs.SettableFileFlags{ + NonBlocking: flags&linux.SFD_NONBLOCK != 0, + }) + + // Create a new descriptor. + fd, err = t.NewFDFrom(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.SFD_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + + // Done. + return uintptr(fd), nil, nil +} + +// Signalfd implements the linux syscall signalfd(2). +func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + sigset := args[1].Pointer() + sigsetsize := args[2].SizeT() + return sharedSignalfd(t, fd, sigset, sigsetsize, 0) +} + +// Signalfd4 implements the linux syscall signalfd4(2). +func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + sigset := args[1].Pointer() + sigsetsize := args[2].SizeT() + flags := args[3].Int() + return sharedSignalfd(t, fd, sigset, sigsetsize, flags) +} diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 8a98fedcb..f0a292f2f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB total int64 n int64 err error - ch chan struct{} - inW bool - outW bool + inCh chan struct{} + outCh chan struct{} ) for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) @@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB break } - // Are we a registered waiter? - if ch == nil { - ch = make(chan struct{}, 1) - } - if !inW && !inFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - inFile.EventRegister(&w, EventMaskRead) - defer inFile.EventUnregister(&w) - inW = true // Registered. - } else if !outW && !outFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - outFile.EventRegister(&w, EventMaskWrite) - defer outFile.EventUnregister(&w) - outW = true // Registered. - } - - // Was anything registered? If no, everything is non-blocking. - if !inW && !outW { - break - } - - if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) { - // Something became ready, try again without blocking. - continue + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(EventMaskRead) == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, EventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } } - - // Block until there's data. - if err = t.Block(ch); err != nil { - break + if outFile.Readiness(EventMaskWrite) == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } } } @@ -149,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Length: count, SrcOffset: true, SrcStart: offset, - }, false) + }, outFile.Flags().NonBlocking) // Copy out the new offset. if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { @@ -159,7 +156,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Send data using splice. n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, - }, false) + }, outFile.Flags().NonBlocking) } // We can only pass a single file to handleIOError, so pick inFile @@ -181,12 +178,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. Note that unlike in Linux, this - // flag is applied consistently. We will have either fully blocking or - // non-blocking behavior below, regardless of the underlying files - // being spliced to. It's unclear if this is a bug or not yet. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -200,6 +191,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } defer inFile.DecRef() + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Construct our options. // // Note that exactly one of the underlying buffers must be a pipe. We @@ -257,7 +255,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Splice data. - n, err := doSplice(t, outFile, inFile, opts, nonBlocking) + n, err := doSplice(t, outFile, inFile, opts, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile) @@ -275,9 +273,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -301,11 +296,14 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } + // The operation is non-blocking if anything is non-blocking. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Splice data. n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, Dup: true, - }, nonBlocking) + }, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 8aa6a3017..beb43ba13 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD index b69603da3..fc7614fff 100644 --- a/pkg/sentry/unimpl/BUILD +++ b/pkg/sentry/unimpl/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -10,6 +11,12 @@ proto_library( deps = ["//pkg/sentry/arch:registers_proto"], ) +cc_proto_library( + name = "unimplemented_syscall_cc_proto", + visibility = ["//visibility:public"], + deps = [":unimplemented_syscall_proto"], +) + go_proto_library( name = "unimplemented_syscall_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto", diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index a5b4206bb..cc5d25762 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "addr_range", diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 0f247bf77..eff4b44f6 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 86bde7fb3..7eb2b2821 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -199,8 +199,11 @@ type Dirent struct { // Ino is the inode number. Ino uint64 - // Off is this Dirent's offset. - Off int64 + // NextOff is the offset of the *next* Dirent in the directory; that is, + // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will + // cause the next call to FileDescription.IterDirents() to yield the next + // Dirent. (The offset of the first Dirent in a directory is always 0.) + NextOff int64 } // IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents. diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index 00665c939..bdca80d37 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/state/BUILD b/pkg/state/BUILD index c0f3c658d..329904457 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index e70f4a79f..8a865d229 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index b149f9e02..bd3f9fd28 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index df37c7d5a..3fd9e3134 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tcpip", diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 0d2637ee4..78df5a0b1 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 3301967fb..b4e8d6810 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "buffer", diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 29b30be9c..0c5c20cea 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 76ef02f13..b558350c3 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "header", diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 093850e25..9d3abc0e4 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -76,6 +76,13 @@ const ( // IPv6Version is the version of the ipv6 protocol. IPv6Version = 6 + // IPv6AllNodesMulticastAddress is a link-local multicast group that + // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all nodes on a link. + // + // The address is ff02::1. + IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 @@ -221,6 +228,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool { return addr[0] == 0xff } +// IsV6UnicastAddress determines if the provided address is a valid IPv6 +// unicast (and specified) address. That is, IsV6UnicastAddress returns +// true if addr contains IPv6AddressSize bytes, is not the unspecified +// address and is not a multicast address. +func IsV6UnicastAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + + // Must not be unspecified + if addr == IPv6Any { + return false + } + + // Return if not a multicast. + return addr[0] != 0xff +} + // SolicitedNodeAddr computes the solicited-node multicast address. This is // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 // address. diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index c1f454805..74412c894 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -27,6 +27,11 @@ const ( udpChecksum = 6 ) +const ( + // UDPMaximumPacketSize is the largest possible UDP packet. + UDPMaximumPacketSize = 0xffff +) + // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c40744b8e..18adb2085 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -44,14 +44,12 @@ type Endpoint struct { } // New creates a new channel endpoint. -func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ +func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { + return &Endpoint{ C: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } - - return stack.RegisterLinkEndpoint(e), e } // Drain removes all outbound packets from the channel and counts them. @@ -135,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } + +// Wait implements stack.LinkEndpoint.Wait. +func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 74fbbb896..8fa9e3984 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 77f988b9f..584db710e 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,6 +41,7 @@ package fdbased import ( "fmt" + "sync" "syscall" "golang.org/x/sys/unix" @@ -81,6 +82,7 @@ const ( PacketMMap ) +// An endpoint implements the link-layer using a message-oriented file descriptor. type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. The endpoint will dispatch from all inbound channels as well as @@ -114,6 +116,9 @@ type endpoint struct { // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is // disabled. gsoMaxSize uint32 + + // wg keeps track of running goroutines. + wg sync.WaitGroup } // Options specify the details about the fd-based endpoint to be created. @@ -164,8 +169,9 @@ type Options struct { // New creates a new fd-based endpoint. // // Makes fd non-blocking, but does not take ownership of fd, which must remain -// open for the lifetime of the returned endpoint. -func New(opts *Options) (tcpip.LinkEndpointID, error) { +// open for the lifetime of the returned endpoint (until after the endpoint has +// stopped being using and Wait returns). +func New(opts *Options) (stack.LinkEndpoint, error) { caps := stack.LinkEndpointCapabilities(0) if opts.RXChecksumOffload { caps |= stack.CapabilityRXChecksumOffload @@ -190,7 +196,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } if len(opts.FDs) == 0 { - return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified") + return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } e := &endpoint{ @@ -207,12 +213,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { for i := 0; i < len(e.fds); i++ { fd := e.fds[i] if err := syscall.SetNonblock(fd, true); err != nil { - return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) + return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) } isSocket, err := isSocketFD(fd) if err != nil { - return 0, err + return nil, err } if isSocket { if opts.GSOMaxSize != 0 { @@ -222,12 +228,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket) if err != nil { - return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err) + return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err) } e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) { @@ -290,7 +296,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { // saved, they stop sending outgoing packets and all incoming packets // are rejected. for i := range e.inboundDispatchers { - go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above. + e.wg.Add(1) + go func(i int) { // S/R-SAFE: See above. + e.dispatchLoop(e.inboundDispatchers[i]) + e.wg.Done() + }(i) } } @@ -320,6 +330,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } +// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop +// reading from its FD. +func (e *endpoint) Wait() { + e.wg.Wait() +} + // virtioNetHdr is declared in linux/virtio_net.h. type virtioNetHdr struct { flags uint8 @@ -435,14 +451,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf } // NewInjectable creates a new fd-based InjectableEndpoint. -func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) { +func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint { syscall.SetNonblock(fd, true) - e := &InjectableEndpoint{endpoint: endpoint{ + return &InjectableEndpoint{endpoint: endpoint{ fds: []int{fd}, mtu: mtu, caps: capabilities, }} - - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e305252d6..04406bc9a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context { } opt.FDs = []int{fds[1]} - epID, err := New(opt) + ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } - ep := stack.FindLinkEndpoint(epID).(*endpoint) c := &context{ t: t, diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ab6a53988..b36629d2c 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -32,8 +32,8 @@ type endpoint struct { // New creates a new loopback endpoint. This link-layer endpoint just turns // outbound packets into inbound packets. -func New() tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{}) +func New() stack.LinkEndpoint { + return &endpoint{} } // Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- @@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa return nil } + +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index ea12ef1ac..1bab380b0 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a577a3d52..7c946101d 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -104,10 +104,16 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) * return endpoint.WriteRawPacket(dest, packet) } +// Wait implements stack.LinkEndpoint.Wait. +func (m *InjectableEndpoint) Wait() { + for _, ep := range m.routes { + ep.Wait() + } +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. -func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) { - e := &InjectableEndpoint{ +func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { + return &InjectableEndpoint{ routes: routes, } - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 174b9330f..3086fec00 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc if err != nil { t.Fatal("Failed to create socket pair:", err) } - _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) + underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - _, endpoint := NewInjectableEndpoint(routes) + endpoint := NewInjectableEndpoint(routes) return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP } diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index f2998aa98..0a5ea3dc4 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 94725cb11..330ed5e94 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 160a8f864..de1ce043d 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 834ea5c40..9e71d4edf 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -94,7 +94,7 @@ type endpoint struct { // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) { +func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { e := &endpoint{ mtu: mtu, bufferSize: bufferSize, @@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc } if err := e.tx.init(bufferSize, &tx); err != nil { - return 0, err + return nil, err } if err := e.rx.init(bufferSize, &rx); err != nil { e.tx.cleanup() - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } // Close frees all resources associated with the endpoint. @@ -132,7 +132,8 @@ func (e *endpoint) Close() { } } -// Wait waits until all workers have stopped after a Close() call. +// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have +// stopped after a Close() call. func (e *endpoint) Wait() { e.completed.Wait() } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 98036f367..0e9ba0846 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) if err != nil { t.Fatalf("New failed: %v", err) } - c.ep = stack.FindLinkEndpoint(id).(*endpoint) + c.ep = ep.(*endpoint) c.ep.Attach(c) return c diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 36c8c46fc..e401dce44 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -58,10 +58,10 @@ type endpoint struct { // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. -func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), - }) +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + return &endpoint{ + lower: lower, + } } func zoneOffset() (int32, error) { @@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { // snapLen is the maximum amount of a packet to be saved. Packets with a length // less than or equal too snapLen will be saved in their entirety. Longer // packets will be truncated to snapLen. -func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) { +func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) { if err := writePCAPHeader(file, snapLen); err != nil { - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), + return &endpoint{ + lower: lower, file: file, maxPCAPLen: snapLen, - }), nil + }, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is @@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return e.lower.WritePacket(r, gso, hdr, payload, protocol) } +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 2597d4b3e..0746dc8ec 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 3b6ac2ff7..5a1791cb5 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -40,11 +40,10 @@ type Endpoint struct { // New creates a new waitable link-layer endpoint. It wraps around another // endpoint and allows the caller to block new write/dispatch calls and wait for // the inflight ones to finish before returning. -func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ - lower: stack.FindLinkEndpoint(lower), +func New(lower stack.LinkEndpoint) *Endpoint { + return &Endpoint{ + lower: lower, } - return stack.RegisterLinkEndpoint(e), e } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. @@ -121,3 +120,6 @@ func (e *Endpoint) WaitWrite() { func (e *Endpoint) WaitDispatch() { e.dispatchGate.Close() } + +// Wait implements stack.LinkEndpoint.Wait. +func (e *Endpoint) Wait() {} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 56e18ecb0..ae23c96b7 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -70,9 +70,12 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P return nil } +// Wait implements stack.LinkEndpoint.Wait. +func (*countedEndpoint) Wait() {} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Write and check that it goes through. wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) @@ -97,7 +100,7 @@ func TestWaitWrite(t *testing.T) { func TestWaitDispatch(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Check that attach happens. wep.Attach(ep) @@ -139,7 +142,7 @@ func TestOtherMethods(t *testing.T) { hdrLen: hdrLen, linkAddr: linkAddr, } - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) if v := wep.MTU(); v != mtu { t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index f36f49453..9d16ff8c9 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d95d44f56..df0d3a8c0 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 4c4b54469..387fca96e 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -47,11 +47,13 @@ func newTestContext(t *testing.T) *testContext { s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) const defaultMTU = 65536 - id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(256, defaultMTU, stackLinkAddr) + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -73,7 +75,7 @@ func newTestContext(t *testing.T) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 118bfc763..c5c7aad86 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "reassembler_list", diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4b3bd74fa..6a40e7ee3 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*testObject) Wait() {} + // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index be84fa63d..58e537aad 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 1b5a55bea..ae827ca27 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -36,11 +36,11 @@ func TestExcludeBroadcast(t *testing.T) { s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) const defaultMTU = 65536 - id, _ := channel.New(256, defaultMTU, "") + ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { - id = sniffer.New(id) + ep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -184,15 +184,12 @@ type errorChannel struct { // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket // will return successive errors from packetCollectorErrors until the list is // empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { - _, e := channel.New(size, mtu, linkAddr) - ec := errorChannel{ - Endpoint: e, +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { + return &errorChannel{ + Endpoint: channel.New(size, mtu, linkAddr), Ch: make(chan packetInfo, size), packetCollectorErrors: packetCollectorErrors, } - - return stack.RegisterLinkEndpoint(e), &ec } // packetInfo holds all the information about an outbound packet. @@ -242,9 +239,8 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) - _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - linkEPId := stack.RegisterLinkEndpoint(linkEP) - s.CreateNIC(1, linkEPId) + ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + s.CreateNIC(1, ep) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -266,7 +262,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 } return context{ Route: r, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index c71b69123..f06622a8b 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -25,6 +26,7 @@ go_test( size = "small", srcs = [ "icmp_test.go", + "ipv6_test.go", "ndp_test.go", ], embed = [":ipv6"], @@ -36,6 +38,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 227a65cf2..653d984e9 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -83,8 +83,7 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li func TestICMPCounts(t *testing.T) { s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -211,36 +210,27 @@ func newTestContext(t *testing.T) *testContext { } const defaultMTU = 65536 - _, linkEP0 := channel.New(256, defaultMTU, linkAddr0) - c.linkEP0 = linkEP0 - wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0} - id0 := stack.RegisterLinkEndpoint(wrappedEP0) + c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + + wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { - id0 = sniffer.New(id0) + wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, id0); err != nil { + if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { - t.Fatalf("AddAddress sn lladdr0: %v", err) - } - _, linkEP1 := channel.New(256, defaultMTU, linkAddr1) - c.linkEP1 = linkEP1 - wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1} - id1 := stack.RegisterLinkEndpoint(wrappedEP1) - if err := c.s1.CreateNIC(1, id1); err != nil { + c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) + if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { - t.Fatalf("AddAddress sn lladdr1: %v", err) - } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) if err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go new file mode 100644 index 000000000..57bcd5455 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -0,0 +1,258 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // The least significant 3 bytes are the same as addr2 so both addr2 and + // addr3 will have the same solicited-node address. + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" +) + +// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the +// expected Neighbor Advertisement received count after receiving the packet. +func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + // Receive ICMP packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stats := s.Stats().ICMP.V6PacketsReceived + + if got := stats.NeighborAdvert.Value(); got != want { + t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) + } +} + +// testReceiveICMP tests receiving a UDP packet from src to dst. want is the +// expected UDP received count after receiving the packet. +func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + + // Receive UDP Packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + + // UDP pseudo-header checksum. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) + + // UDP checksum + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stat := s.Stats().UDP.PacketsReceived + + if got := stat.Value(); got != want { + t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) + } +} + +// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and +// UDP packets destined to the IPv6 link-local all-nodes multicast address. +func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { + tests := []struct { + name string + protocolName string + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.ProtocolName6, testReceiveICMP}, + {"UDP", udp.ProtocolName, testReceiveUDP}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should receive a packet destined to the all-nodes + // multicast address. + test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) + }) + } +} + +// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP +// packets destined to the IPv6 solicited-node address of an assigned IPv6 +// address. +func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + tests := []struct { + name string + protocolName string + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.ProtocolName6, testReceiveICMP}, + {"UDP", udp.ProtocolName, testReceiveUDP}, + } + + snmc := header.SolicitedNodeAddr(addr2) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as we haven't added + // those addresses. + test.rxf(t, s, e, addr1, snmc, 0) + + if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + } + + // Should receive a packet destined to the solicited + // node address of addr2/addr3 now that we have added + // added addr2. + test.rxf(t, s, e, addr1, snmc, 1) + + if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have added addr3. + test.rxf(t, s, e, addr1, snmc, 2) + + if err := s.RemoveAddress(1, addr2); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have removed addr2. + test.rxf(t, s, e, addr1, snmc, 3) + + if err := s.RemoveAddress(1, addr3); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as both of them got + // removed. + test.rxf(t, s, e, addr1, snmc, 3) + }) + } +} + +// TestAddIpv6Address tests adding IPv6 addresses. +func TestAddIpv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + }{ + // This test is in response to b/140943433. + { + "Nil", + tcpip.Address([]byte(nil)), + }, + { + "ValidUnicast", + addr1, + }, + { + "ValidLinkLocalUnicast", + lladdr0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, nil, stack.Options{}) + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8e4cf0e74..571915d3f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -32,15 +32,14 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Helper() s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) - { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } + + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) } + { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 989058413..11efb4e44 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index e2021cd15..f12189580 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -138,11 +138,11 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) + linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) if err != nil { log.Fatal(err) } - if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil { + if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 1716be285..329941775 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -128,7 +128,7 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{ + linkEP, err := fdbased.New(&fdbased.Options{ FDs: []int{fd}, MTU: mtu, EthernetHeader: *tap, @@ -137,7 +137,7 @@ func main() { if err != nil { log.Fatal(err) } - if err := s.CreateNIC(1, linkID); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 788de3dfe..28c49e8ff 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "linkaddrentry_list", diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go index f8156be47..3a20839da 100644 --- a/pkg/tcpip/stack/icmp_rate_limit.go +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "golang.org/x/time/rate" ) @@ -33,54 +31,11 @@ const ( // ICMPRateLimiter is a global rate limiter that controls the generation of // ICMP messages generated by the stack. type ICMPRateLimiter struct { - mu sync.RWMutex - l *rate.Limiter + *rate.Limiter } // NewICMPRateLimiter returns a global rate limiter for controlling the rate // at which ICMP messages are generated by the stack. func NewICMPRateLimiter() *ICMPRateLimiter { - return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)} -} - -// Allow returns true if we are allowed to send at least 1 message at the -// moment. -func (i *ICMPRateLimiter) Allow() bool { - i.mu.RLock() - allow := i.l.Allow() - i.mu.RUnlock() - return allow -} - -// Limit returns the maximum number of ICMP messages that can be sent in one -// second. -func (i *ICMPRateLimiter) Limit() rate.Limit { - i.mu.RLock() - defer i.mu.RUnlock() - return i.l.Limit() -} - -// SetLimit sets the maximum number of ICMP messages that can be sent in one -// second. -func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) { - i.mu.RLock() - defer i.mu.RUnlock() - i.l.SetLimit(newLimit) -} - -// Burst returns how many ICMP messages can be sent at any single instant. -func (i *ICMPRateLimiter) Burst() int { - i.mu.RLock() - defer i.mu.RUnlock() - return i.l.Burst() -} - -// SetBurst sets the maximum number of ICMP messages allowed at any single -// instant. -// -// NOTE: Changing Burst causes the underlying rate limiter to be recreated. -func (i *ICMPRateLimiter) SetBurst(burst int) { - i.mu.Lock() - i.l = rate.NewLimiter(i.l.Limit(), burst) - i.mu.Unlock() + return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43719085e..a719058b4 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -102,6 +102,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + } + + return nil +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -307,6 +326,8 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP address before adding them. + // Sanity check. id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if _, ok := n.endpoints[id]; ok { @@ -339,6 +360,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 unicast address, join the solicited-node + // multicast address. + if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -467,13 +497,27 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || r.getKind() != permanent { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } r.setKind(permanentExpired) - r.decRefLocked() + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } + + // At this point the endpoint is deleted. + + // If we are removing an IPv6 unicast address, leave the solicited-node + // multicast address. + if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -491,6 +535,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -518,6 +569,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -802,11 +860,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 67b70b2ee..07e4c770d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -297,6 +295,15 @@ type LinkEndpoint interface { // IsAttached returns whether a NetworkDispatcher is attached to the // endpoint. IsAttached() bool + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // For now, requesting that an endpoint's worker goroutine(s) stop is + // implementation specific. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -379,10 +386,6 @@ var ( networkProtocols = make(map[string]NetworkProtocolFactory) unassociatedFactory UnassociatedEndpointFactory - - linkEPMu sync.RWMutex - nextLinkEndpointID tcpip.LinkEndpointID = 1 - linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) ) // RegisterTransportProtocolFactory registers a new transport protocol factory @@ -406,28 +409,6 @@ func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { unassociatedFactory = f } -// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an -// ID that can be used to refer to it. -func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { - linkEPMu.Lock() - defer linkEPMu.Unlock() - - v := nextLinkEndpointID - nextLinkEndpointID++ - - linkEndpoints[v] = linkEP - - return v -} - -// FindLinkEndpoint finds the link endpoint associated with the given ID. -func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { - linkEPMu.RLock() - defer linkEPMu.RUnlock() - - return linkEndpoints[id] -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6beca6ae8..1fe21b68e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -620,12 +620,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { - ep := FindLinkEndpoint(linkEP) - if ep == nil { - return tcpip.ErrBadLinkEndpoint - } - +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -638,40 +633,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint s.nics[id] = n if enabled { - n.attachLinkEndpoint() + return n.enable() } return nil } // CreateNIC creates a NIC with the provided id and link-layer endpoint. -func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true, false) +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, false) +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, false) } // CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer // endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, true) +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false, false) +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false, false) +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -685,9 +680,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - nic.attachLinkEndpoint() - - return nil + return nic.enable() } // CheckNIC checks if a NIC is usable. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c6a8160af..0c26c9911 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct { prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { @@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto } func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() + return f.ep.Capabilities() } func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return nil } - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) } func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -189,14 +189,14 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, - linkEP: linkEP, + ep: ep, }, nil } @@ -225,9 +225,9 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") + ep := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -245,7 +245,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -255,7 +255,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -265,7 +265,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -274,7 +274,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -284,7 +284,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -307,59 +307,59 @@ func send(r stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) } -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := sendTo(s, addr, payload); err != nil { t.Error("sendTo failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("sendTo packet count: got = %d, want %d", got, want) } } -func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := send(r, payload); err != nil { t.Error("send failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("send packet count: got = %d, want %d", got, want) } } -func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) } } -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we expect to receive it. want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we do NOT expect to receive it. want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -369,9 +369,9 @@ func TestNetworkSend(t *testing.T) { // Create a stack with the fake network protocol, one nic, and one // address: 1. The route table sends all packets through the only // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") + ep := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -388,7 +388,7 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", linkEP, nil) + testSendTo(t, s, "\x03", ep, nil) } func TestNetworkSendMultiRoute(t *testing.T) { @@ -397,8 +397,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { // even addresses. s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -410,8 +410,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -442,10 +442,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - testSendTo(t, s, "\x05", linkEP1, nil) + testSendTo(t, s, "\x05", ep1, nil) // Send a packet to an even destination. - testSendTo(t, s, "\x06", linkEP2, nil) + testSendTo(t, s, "\x06", ep2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -478,8 +478,8 @@ func TestRoutes(t *testing.T) { // even addresses. s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -491,8 +491,8 @@ func TestRoutes(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -556,8 +556,8 @@ func TestAddressRemoval(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -578,15 +578,15 @@ func TestAddressRemoval(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -601,9 +601,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatal("CreateNIC failed:", err) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) @@ -626,17 +626,17 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -690,8 +690,8 @@ func TestEndpointExpiration(t *testing.T) { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -724,15 +724,15 @@ func TestEndpointExpiration(t *testing.T) { //----------------------- verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 2. Add Address, everything should work. @@ -741,8 +741,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. @@ -752,15 +752,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 4. Add Address back, everything should work again. @@ -769,8 +769,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 5. Take a reference to the endpoint by getting a route. Verify that // we can still send/receive, including sending using the route. @@ -779,9 +779,9 @@ func TestEndpointExpiration(t *testing.T) { if err != nil { t.Fatal("FindRoute failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -791,16 +791,16 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 7. Add Address back, everything should work again. @@ -809,16 +809,16 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -828,15 +828,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } }) } @@ -846,8 +846,8 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -867,13 +867,13 @@ func TestPromiscuousMode(t *testing.T) { // have a matching endpoint. const localAddrByte byte = 0x01 buf[0] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -886,7 +886,7 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestSpoofingWithAddress(t *testing.T) { @@ -896,8 +896,8 @@ func TestSpoofingWithAddress(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -936,8 +936,8 @@ func TestSpoofingWithAddress(t *testing.T) { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } // Sending a packet works. - testSendTo(t, s, dstAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testSendTo(t, s, dstAddr, ep, nil) + testSend(t, r, ep, nil) // FindRoute should also work with a local address that exists on the NIC. r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) @@ -951,7 +951,7 @@ func TestSpoofingWithAddress(t *testing.T) { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. - testSend(t, r, linkEP, nil) + testSend(t, r, ep, nil) } func TestSpoofingNoAddress(t *testing.T) { @@ -960,8 +960,8 @@ func TestSpoofingNoAddress(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -980,7 +980,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -999,14 +999,14 @@ func TestSpoofingNoAddress(t *testing.T) { } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) @@ -1076,8 +1076,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1132,8 +1132,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1159,7 +1159,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddAddressRange failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { @@ -1198,8 +1198,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1236,8 +1236,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1262,7 +1262,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddAddressRange failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestNetworkOptions(t *testing.T) { @@ -1320,8 +1320,8 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S func TestAddresRangeAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1361,8 +1361,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } // Insert <canBe> primary and <never> never-primary addresses. @@ -1426,8 +1426,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1501,8 +1501,8 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1526,8 +1526,8 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1558,8 +1558,8 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1587,8 +1587,8 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1621,8 +1621,8 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { @@ -1639,7 +1639,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1653,9 +1653,9 @@ func TestNICStats(t *testing.T) { if err := sendTo(s, "\x01", payload); err != nil { t.Fatal("sendTo failed: ", err) } - want := uint64(linkEP1.Drain()) + want := uint64(ep1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) + t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { @@ -1669,16 +1669,16 @@ func TestNICForwarding(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) s.SetForwarding(true) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -1697,10 +1697,10 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) select { - case <-linkEP2.C: + case <-ep2.C: default: t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ca185279e..0e69ac7c8 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } @@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptInt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -278,9 +283,9 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -340,9 +345,9 @@ func TestTransportReceive(t *testing.T) { } func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -408,9 +413,9 @@ func TestTransportControlReceive(t *testing.T) { } func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -497,16 +502,16 @@ func TestTransportForwarding(t *testing.T) { s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { + ep1 := loopback.New() + if err := s.CreateNIC(1, ep1); err != nil { t.Fatalf("CreateNIC #1 failed: %v", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress #1 failed: %v", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatalf("CreateNIC #2 failed: %v", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -545,7 +550,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.Inject(fakeNetNumber, req.ToVectorisedView()) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -559,7 +564,7 @@ func TestTransportForwarding(t *testing.T) { var p channel.PacketInfo select { - case p = <-linkEP2.C: + case p = <-ep2.C: default: t.Fatal("Response packet not forwarded") } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 418e771d2..c021c67ac 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -261,31 +261,34 @@ type FullAddress struct { Port uint16 } -// Payload provides an interface around data that is being sent to an endpoint. -// This allows the endpoint to request the amount of data it needs based on -// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data. -type Payload interface { - // Get returns a slice containing exactly 'min(size, p.Size())' bytes. - Get(size int) ([]byte, *Error) - - // Size returns the payload size. - Size() int +// Payloader is an interface that provides data. +// +// This interface allows the endpoint to request the amount of data it needs +// based on internal buffers without exposing them. +type Payloader interface { + // FullPayload returns all available bytes. + FullPayload() ([]byte, *Error) + + // Payload returns a slice containing at most size bytes. + Payload(size int) ([]byte, *Error) } -// SlicePayload implements Payload on top of slices for convenience. +// SlicePayload implements Payloader for slices. +// +// This is typically used for tests. type SlicePayload []byte -// Get implements Payload. -func (s SlicePayload) Get(size int) ([]byte, *Error) { - if size > s.Size() { - size = s.Size() - } - return s[:size], nil +// FullPayload implements Payloader.FullPayload. +func (s SlicePayload) FullPayload() ([]byte, *Error) { + return s, nil } -// Size implements Payload. -func (s SlicePayload) Size() int { - return len(s) +// Payload implements Payloader.Payload. +func (s SlicePayload) Payload(size int) ([]byte, *Error) { + if size > len(s) { + size = len(s) + } + return s[:size], nil } // A ControlMessages contains socket control messages for IP sockets. @@ -338,7 +341,7 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // @@ -398,6 +401,10 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptInt sets a socket option, for simple cases where a value + // has the int type. + SetSockOptInt(opt SockOpt, v int) *Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error @@ -432,16 +439,33 @@ type WriteOptions struct { // EndOfRecord has the same semantics as Linux's MSG_EOR. EndOfRecord bool + + // Atomic means that all data fetched from Payloader must be written to the + // endpoint. If Atomic is false, then data fetched from the Payloader may be + // discarded if available endpoint buffer space is unsufficient. + Atomic bool } // SockOpt represents socket options which values have the int type. type SockOpt int const ( - // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of - // unread bytes in the input buffer should be returned. + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption SockOpt = iota + // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the send buffer size option. + SendBufferSizeOption + + // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the receive buffer size option. + ReceiveBufferSizeOption + + // SendQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the output buffer should be returned. + SendQueueSizeOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -450,18 +474,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send -// buffer size option. -type SendBufferSizeOption int - -// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the -// receive buffer size option. -type ReceiveBufferSizeOption int - -// SendQueueSizeOption is used in GetSockOpt to specify that the number of -// unread bytes in the output buffer should be returned. -type SendQueueSizeOption int - // V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int @@ -600,9 +612,6 @@ func (r Route) String() string { return out.String() } -// LinkEndpointID represents a data link layer endpoint. -type LinkEndpointID uint64 - // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e1f622af6..a111fdb2a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -204,7 +204,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -289,7 +289,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } @@ -319,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// SetSockOptInt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { @@ -331,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -341,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 13e17e2a6..a02731a5d 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -207,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -220,9 +220,8 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 return 0, nil, tcpip.ErrInvalidEndpointState } - payloadBytes, err := payload.Get(payload.Size()) + payloadBytes, err := p.FullPayload() if err != nil { - ep.mu.RUnlock() return 0, nil, err } @@ -230,7 +229,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 // destination address, route using that address. if !ep.associated { ip := header.IPv4(payloadBytes) - if !ip.IsValid(payload.Size()) { + if !ip.IsValid(len(payloadBytes)) { ep.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } @@ -493,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { @@ -505,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } ep.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSize + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption @@ -516,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - ep.mu.Lock() - *o = tcpip.SendBufferSizeOption(ep.sndBufSize) - ep.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) - ep.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 1ee1a53f8..39a839ab7 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "tcp_segment_list", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ac927569a..35b489c68 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -806,7 +806,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -821,47 +821,52 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, err } - e.sndBufMu.Unlock() - e.mu.RUnlock() - - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + e.mu.RUnlock() } - // Copy in memory without holding sndBufMu so that worker goroutine can - // make progress independent of this operation. - v, perr := p.Get(avail) - if perr != nil { + // Fetch data. + v, perr := p.Payload(avail) + if perr != nil || len(v) == 0 { + if opts.Atomic { // See above. + e.sndBufMu.Unlock() + e.mu.RUnlock() + } + // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - e.mu.RLock() - e.sndBufMu.Lock() + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a - // write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - return 0, nil, err - } + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } - // Discard any excess data copied in due to avail being reduced due to a - // simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } // Add data to the send queue. - l := len(v) s := newSegmentFromView(&e.route, e.id, v) - e.sndBufUsed += l - e.sndBufInQueue += seqnum.Size(l) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() // Release the endpoint lock to prevent deadlocks due to lock // order inversion when acquiring workMu. @@ -875,7 +880,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return int64(l), nil, nil + + return int64(len(v)), nil, nil } // Peek reads data without consuming it from the endpoint. @@ -946,62 +952,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil - - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.reuseAddr = v != 0 - e.mu.Unlock() - return nil - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.mu.Lock() - e.userMSS = int(userMSS) - e.mu.Unlock() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - +// SetSockOptInt sets a socket option. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1065,6 +1018,67 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sndBufMu.Unlock() return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.DelayOption: + if v == 0 { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.delay, 1) + } + return nil + + case tcpip.CorkOption: + if v == 0 { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + + case tcpip.QuickAckOption: + if v == 0 { + atomic.StoreUint32(&e.slowAck, 1) + } else { + atomic.StoreUint32(&e.slowAck, 0) + } + return nil + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { @@ -1176,6 +1190,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + v := e.sndBufSize + e.sndBufMu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + v := e.rcvBufSize + e.rcvListMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -1198,18 +1224,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.sndBufMu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize) - e.rcvListMu.Unlock() - return nil - case *tcpip.DelayOption: *o = 0 if v := atomic.LoadUint32(&e.delay); v != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 272bbcdbd..9fa97528b 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { enableCUBIC(t, c) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 32bb45224..7fa5cfb6e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { stats := c.Stack().Stats() // SYN and ACK want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) @@ -299,7 +299,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -323,7 +323,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -344,14 +344,14 @@ func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) } func TestNonBlockingClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -367,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -417,7 +417,7 @@ func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -469,7 +469,7 @@ func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -557,8 +557,7 @@ func TestOutOfOrderFlood(t *testing.T) { defer c.Cleanup() // Create a new connection with initial window size of 10. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -631,7 +630,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -700,7 +699,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -785,7 +784,7 @@ func TestShutdownRead(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -804,8 +803,7 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -872,11 +870,9 @@ func TestNoWindowShrinking(t *testing.T) { defer c.Cleanup() // Start off with a window size of 10, then shrink it to 5. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) - opt = 5 - if err := c.EP.SetSockOpt(opt); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { t.Fatalf("SetSockOpt failed: %v", err) } @@ -976,7 +972,7 @@ func TestSimpleSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1017,7 +1013,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 0, nil) + c.CreateConnected(789, 0, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1075,8 +1071,7 @@ func TestScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1110,8 +1105,7 @@ func TestNonScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1151,7 +1145,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1224,7 +1218,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1293,8 +1287,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Set the window size such that a window scale of 4 will be used. const wnd = 65535 * 10 const ws = uint32(4) - opt := tcpip.ReceiveBufferSizeOption(wnd) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1399,7 +1392,7 @@ func TestSegmentMerging(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Prevent the endpoint from processing packets. test.stop(c.EP) @@ -1449,7 +1442,7 @@ func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1497,7 +1490,7 @@ func TestUndelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1579,7 +1572,7 @@ func TestMSSNotDelayed(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -1695,7 +1688,7 @@ func TestSendGreaterThanMTU(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) testBrokenUpWrite(t, c, maxPayload) } @@ -1704,7 +1697,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) testBrokenUpWrite(t, c, maxPayload) @@ -1727,7 +1720,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1871,7 +1864,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 2 - if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1973,7 +1966,7 @@ func TestReceiveOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2010,7 +2003,7 @@ func TestSendOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2035,7 +2028,7 @@ func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2078,7 +2071,7 @@ func TestFinRetransmit(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2132,7 +2125,7 @@ func TestFinWithNoPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. view := buffer.NewView(10) @@ -2203,7 +2196,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write enough segments to fill the congestion window before ACK'ing // any of them. @@ -2291,7 +2284,7 @@ func TestFinWithPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) @@ -2377,7 +2370,7 @@ func TestFinWithPartialAck(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. @@ -2509,7 +2502,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { defer c.Cleanup() maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), header.TCPOptionWS, 3, scale, header.TCPOptionNOP, }) @@ -2559,7 +2552,7 @@ func TestScaledSendWindow(t *testing.T) { func TestReceivedValidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ValidSegmentsReceived.Value() + 1 @@ -2580,7 +2573,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.InvalidSegmentsReceived.Value() + 1 vv := c.BuildSegment(nil, &context.Headers{ @@ -2604,7 +2597,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { func TestReceivedIncorrectChecksumIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ChecksumErrors.Value() + 1 vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ @@ -2635,7 +2628,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send 200 segments. data := []byte{1, 2, 3} @@ -2681,7 +2674,7 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -2856,8 +2849,8 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2869,8 +2862,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2945,26 +2938,26 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -3231,7 +3224,7 @@ func TestPathMTUDiscovery(t *testing.T) { // Create new connection with MSS of 1460. const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -3308,7 +3301,7 @@ func TestTCPEndpointProbe(t *testing.T) { invoked <- struct{}{} }) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -3482,7 +3475,7 @@ func TestKeepalive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 18c707a57..78eff5c3a 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -150,11 +150,12 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - id, linkEP := channel.New(1000, mtu, "") + ep := channel.New(1000, mtu, "") + wep := stack.LinkEndpoint(ep) if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -180,7 +181,7 @@ func New(t *testing.T, mtu uint32) *Context { return &Context{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), } } @@ -511,7 +512,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { } // CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { +func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } @@ -589,7 +590,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // // It also sets the receive buffer for the endpoint to the specified // value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -597,8 +598,8 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("NewEndpoint failed: %v", err) } - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + if epRcvBuf != -1 { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 4bec48c0f..43fcc27f0 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index ac2666f69..c1ca22b35 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "udp_packet_list", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 66455ef46..0bec7e62d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "math" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -277,17 +276,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - if p.Size() > math.MaxUint16 { - // Payload can't possibly fit in a packet. - return 0, nil, tcpip.ErrMessageTooLong - } - to := opts.To e.mu.RLock() @@ -370,10 +364,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } ttl := route.DefaultTTL() if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { @@ -391,7 +389,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: @@ -570,7 +573,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil } + return -1, tcpip.ErrUnknownProtocolOption } @@ -580,18 +596,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { @@ -747,6 +751,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { } e.state = StateBound } else { + if e.id.LocalPort != 0 { + // Release the ephemeral port. + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + } e.state = StateInitial } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 995d6e8a1..c6deab892 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -275,12 +275,13 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + ep := channel.New(256, mtu, "") + wep := stack.LinkEndpoint(ep) - id, linkEP := channel.New(256, mtu, "") if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -306,7 +307,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 98d51cc69..6afdb29b7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index cbd92fc05..8f6f180e5 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b7f505a84..b6bbb0ea2 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 9173dfd0f..8dc88becb 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "waiter_list", |