Diffstat (limited to 'pkg')
-rw-r--r--  pkg/abi/linux/msgqueue.go | 2
-rw-r--r--  pkg/merkletree/merkletree.go | 91
-rw-r--r--  pkg/p9/file.go | 16
-rw-r--r--  pkg/sentry/control/BUILD | 8
-rw-r--r--  pkg/sentry/control/control.proto | 40
-rw-r--r--  pkg/sentry/devices/quotedev/BUILD | 16
-rw-r--r--  pkg/sentry/devices/quotedev/quotedev.go | 52
-rw-r--r--  pkg/sentry/fsimpl/gofer/handle.go | 41
-rw-r--r--  pkg/sentry/fsimpl/gofer/save_restore.go | 4
-rw-r--r--  pkg/sentry/fsimpl/gofer/special_file.go | 131
-rw-r--r--  pkg/sentry/fsimpl/verity/BUILD | 14
-rw-r--r--  pkg/sentry/fsimpl/verity/filesystem.go | 47
-rw-r--r--  pkg/sentry/fsimpl/verity/verity.go | 261
-rw-r--r--  pkg/sentry/kernel/ipc/object.go | 35
-rw-r--r--  pkg/sentry/kernel/msgqueue/msgqueue.go | 109
-rw-r--r--  pkg/sentry/kernel/semaphore/semaphore.go | 12
-rw-r--r--  pkg/sentry/kernel/shm/shm.go | 19
-rw-r--r--  pkg/sentry/kernel/thread_group.go | 15
-rw-r--r--  pkg/sentry/mm/aio_context.go | 18
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_fault.go | 6
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go | 4
-rw-r--r--  pkg/sentry/syscalls/linux/sys_msgqueue.go | 53
-rw-r--r--  pkg/sentry/syscalls/linux/sys_sem.go | 19
-rw-r--r--  pkg/sentry/vfs/README.md | 4
-rw-r--r--  pkg/tcpip/checker/checker.go | 2
-rw-r--r--  pkg/tcpip/header/eth.go | 6
-rw-r--r--  pkg/tcpip/header/eth_test.go | 4
-rw-r--r--  pkg/tcpip/link/channel/channel.go | 3
-rw-r--r--  pkg/tcpip/link/ethernet/ethernet.go | 28
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint.go | 3
-rw-r--r--  pkg/tcpip/link/loopback/loopback.go | 3
-rw-r--r--  pkg/tcpip/link/muxed/injectable.go | 5
-rw-r--r--  pkg/tcpip/link/nested/nested.go | 5
-rw-r--r--  pkg/tcpip/link/pipe/pipe.go | 3
-rw-r--r--  pkg/tcpip/link/qdisc/fifo/endpoint.go | 5
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem.go | 3
-rw-r--r--  pkg/tcpip/link/waitable/waitable.go | 3
-rw-r--r--  pkg/tcpip/link/waitable/waitable_test.go | 5
-rw-r--r--  pkg/tcpip/network/internal/testutil/testutil.go | 5
-rw-r--r--  pkg/tcpip/stack/forwarding_test.go | 4
-rw-r--r--  pkg/tcpip/stack/registration.go | 8
-rw-r--r--  pkg/tcpip/stack/stack.go | 11
-rw-r--r--  pkg/tcpip/stack/transport_demuxer.go | 2
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go | 30
-rw-r--r--  pkg/tcpip/transport/tcp/BUILD | 1
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go | 64
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go | 112
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go | 44
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint_state.go | 1
-rw-r--r--  pkg/tcpip/transport/tcp/forwarder.go | 4
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go | 50
-rw-r--r--  pkg/tcpip/transport/tcp/rack.go | 3
-rw-r--r--  pkg/tcpip/transport/tcp/snd.go | 8
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_rack_test.go | 9
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_sack_test.go | 10
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_test.go | 289
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 8
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/context.go | 81
58 files changed, 1362 insertions(+), 477 deletions(-)
diff --git a/pkg/abi/linux/msgqueue.go b/pkg/abi/linux/msgqueue.go
index e1e8d0357..0612a8214 100644
--- a/pkg/abi/linux/msgqueue.go
+++ b/pkg/abi/linux/msgqueue.go
@@ -47,7 +47,7 @@ const (
MSGSSZ = 16
// MSGSEG is simplified due to the nonexistence of a ternary operator.
- MSGSEG = (MSGPOOL * 1024) / MSGSSZ
+ MSGSEG = 0xffff
)
// MsqidDS is equivalent to struct msqid64_ds. Source:
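The new constant matches Linux's clamp: include/uapi/linux/msg.h computes __MSGSEG as (MSGPOOL * 1024) / MSGSSZ and caps it at 0xffff with a ternary, which a Go const declaration cannot express. A quick check of the arithmetic, assuming the Linux defaults (MSGMNI = 32000, MSGMNB = 16384, MSGSSZ = 16, MSGPOOL = MSGMNI * MSGMNB / 1024), shows the raw value always exceeds the cap, so writing 0xffff directly is equivalent:

package main

import "fmt"

const (
	msgMNI  = 32000
	msgMNB  = 16384
	msgSSZ  = 16
	msgPool = msgMNI * msgMNB / 1024 // 512000 KiB of message pool
)

func main() {
	raw := msgPool * 1024 / msgSSZ // 32768000
	seg := raw
	if seg > 0xffff { // the clamp Linux writes as a ternary
		seg = 0xffff
	}
	fmt.Printf("raw=%d clamped=%#x\n", raw, seg) // raw=32768000 clamped=0xffff
}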
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 0b961d3d9..6358ad8e9 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -384,6 +384,14 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error {
return descriptor.verify(params.Expected, params.HashAlgorithms)
}
+// cachedHashes stores verified hashes from a previous hash step.
+type cachedHashes struct {
+ // offset is the offset of the cached hash at each level.
+ offset []int64
+ // hash is the verified hash cache for each level from previous hash steps.
+ hash [][]byte
+}
+
// Verify verifies the content read from data with offset. The content is
// verified against tree. If content spans across multiple blocks, each block is
// verified. Verification fails if the hash of the data does not match the tree
@@ -409,29 +417,32 @@ func Verify(params *VerifyParams) (int64, error) {
firstDataBlock := params.ReadOffset / layout.blockSize
lastDataBlock := (params.ReadOffset + params.ReadSize - 1) / layout.blockSize
- buf := make([]byte, layout.blockSize)
- var readErr error
- total := int64(0)
+ size := (lastDataBlock - firstDataBlock + 1) * layout.blockSize
+ retBuf := make([]byte, size)
+ n, err := params.File.ReadAt(retBuf, firstDataBlock*layout.blockSize)
+ if err != nil && err != io.EOF {
+ return 0, err
+ }
+ total := int64(n)
+ bytesRead := int64(0)
+
+ // Only cache hash results if reading more than a block.
+ var ch *cachedHashes
+ if lastDataBlock > firstDataBlock {
+ ch = &cachedHashes{
+ offset: make([]int64, layout.numLevels()),
+ hash: make([][]byte, layout.numLevels()),
+ }
+ }
for i := firstDataBlock; i <= lastDataBlock; i++ {
+ // Reached the end of the file during verification.
+ if total <= 0 {
+ return bytesRead, io.EOF
+ }
// Read a block that includes all or part of target range in
// input data.
- bytesRead, err := params.File.ReadAt(buf, i*layout.blockSize)
- readErr = err
- // If at the end of input data and all previous blocks are
- // verified, return the verified input data and EOF.
- if readErr == io.EOF && bytesRead == 0 {
- break
- }
- if readErr != nil && readErr != io.EOF {
- return 0, fmt.Errorf("read from data failed: %w", err)
- }
- // If this is the end of file, zero the remaining bytes in buf,
- // otherwise they are still from the previous block.
- if bytesRead < len(buf) {
- for j := bytesRead; j < len(buf); j++ {
- buf[j] = 0
- }
- }
+ buf := retBuf[(i-firstDataBlock)*layout.blockSize : (i-firstDataBlock+1)*layout.blockSize]
+
descriptor := VerityDescriptor{
Name: params.Name,
FileSize: params.Size,
@@ -441,8 +452,8 @@ func Verify(params *VerifyParams) (int64, error) {
SymlinkTarget: params.SymlinkTarget,
Children: params.Children,
}
- if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil {
- return 0, err
+ if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected, ch); err != nil {
+ return bytesRead, err
}
// startOff is the beginning of the read range within the
@@ -459,22 +470,24 @@ func Verify(params *VerifyParams) (int64, error) {
if i == lastDataBlock {
endOff = (params.ReadOffset+params.ReadSize-1)%layout.blockSize + 1
}
+
// If the provided size exceeds the end of input data, we should
// only copy the parts of buf that are part of the input data.
- if startOff > int64(bytesRead) {
- startOff = int64(bytesRead)
+ if startOff > total {
+ startOff = total
}
- if endOff > int64(bytesRead) {
- endOff = int64(bytesRead)
+ if endOff > total {
+ endOff = total
}
+
n, err := params.Out.Write(buf[startOff:endOff])
if err != nil {
- return total, err
+ return bytesRead, err
}
- total += int64(n)
-
+ bytesRead += int64(n)
+ total -= endOff
}
- return total, readErr
+ return bytesRead, nil
}
// verifyBlock verifies a block against tree. index is the number of block in
@@ -482,7 +495,7 @@ func Verify(params *VerifyParams) (int64, error) {
// fails if the calculated hash from block is different from any level of
// hashes stored in tree. And the final root hash is compared with
// expected.
-func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error {
+func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte, ch *cachedHashes) error {
if len(dataBlock) != int(layout.blockSize) {
return fmt.Errorf("incorrect block size")
}
@@ -491,6 +504,12 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
treeBlock := make([]byte, layout.blockSize)
var digest []byte
for level := 0; level < layout.numLevels(); level++ {
+ // No need to verify remaining levels if the current block has
+ // been verified in a previous call and cached.
+ if ch != nil && ch.offset[level] == layout.digestOffset(level, blockIndex) && ch.hash[level] != nil {
+ break
+ }
+
// Calculate hash.
if level == 0 {
h, err := hashData(dataBlock, hashAlgorithms)
@@ -521,11 +540,19 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
if !bytes.Equal(digest, expectedDigest) {
return fmt.Errorf("verification failed")
}
+ if ch != nil {
+ ch.offset[level] = layout.digestOffset(level, blockIndex)
+ ch.hash[level] = expectedDigest
+ }
blockIndex = blockIndex / layout.hashesPerBlock()
}
// Verification for the tree succeeded. Now hash the descriptor with
// the root hash and compare it with expected.
- descriptor.RootHash = digest
+ if ch != nil {
+ descriptor.RootHash = ch.hash[layout.rootLevel()]
+ } else {
+ descriptor.RootHash = digest
+ }
return descriptor.verify(expected, hashAlgorithms)
}
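The payoff of cachedHashes comes from the tree geometry: consecutive data blocks share upper-level hash blocks, because verifyBlock divides blockIndex by hashesPerBlock() at each level. Illustrative arithmetic, assuming 4 KiB blocks and 32-byte SHA-256 digests (neither number is taken from this diff):

package main

import "fmt"

func main() {
	const blockSize = 4096 // assumed layout.blockSize
	const digestSize = 32  // assumed SHA-256 digest size
	hashesPerBlock := blockSize / digestSize // 128 digests per hash block

	readBlocks := 256 // a 1 MiB read spans 256 data blocks
	// Without the cache, every data block re-verifies each upper level.
	// With it, a level is re-verified only when its digest offset changes.
	withoutCache := readBlocks
	withCache := (readBlocks + hashesPerBlock - 1) / hashesPerBlock
	fmt.Printf("level-1 verifications: %d -> %d\n", withoutCache, withCache) // 256 -> 2
}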
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index 97e0231d6..8d6af2d6b 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -325,6 +325,12 @@ func (*DisallowServerCalls) Renamed(File, string) {
func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
stats := make([]FullStat, 0, len(names))
parent := start
+ closeParent := func() {
+ if parent != start {
+ _ = parent.Close()
+ }
+ }
+ defer closeParent()
mask := AttrMaskAll()
for i, name := range names {
if len(name) == 0 && i == 0 {
@@ -340,15 +346,14 @@ func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
continue
}
qids, child, valid, attr, err := parent.WalkGetAttr([]string{name})
- if parent != start {
- _ = parent.Close()
- }
if err != nil {
if errors.Is(err, unix.ENOENT) {
return stats, nil
}
return nil, err
}
+ closeParent()
+ parent = child
stats = append(stats, FullStat{
QID: qids[0],
Valid: valid,
@@ -357,13 +362,8 @@ func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
if attr.Mode.FileType() != ModeDirectory {
// Doesn't need to continue if entry is not a dir. Including symlinks
// that cannot be followed.
- _ = child.Close()
break
}
- parent = child
- }
- if parent != start {
- _ = parent.Close()
}
return stats, nil
}
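The restructured walk is a small ownership pattern: keep exactly one intermediate file open, close it as soon as the loop steps past it, and let one deferred closer cover every early return while never closing the caller-owned start. A self-contained sketch of the same shape (node is a stand-in for p9.File, reduced to the two operations the pattern needs):

package main

import (
	"errors"
	"fmt"
)

type node struct {
	name     string
	children map[string]*node
}

func (n *node) walk(name string) (*node, error) {
	if c, ok := n.children[name]; ok {
		return c, nil
	}
	return nil, errors.New("ENOENT")
}

func (n *node) close() { fmt.Println("closed", n.name) }

// visit mirrors DefaultMultiGetAttr's ownership rule: every node opened by
// the walk is closed exactly once, start is never closed, and one deferred
// closer covers all early returns.
func visit(start *node, names []string) error {
	parent := start
	closeParent := func() {
		if parent != start {
			parent.close()
		}
	}
	defer closeParent()
	for _, name := range names {
		child, err := parent.walk(name)
		if err != nil {
			return err // defer closes the last intermediate, never start
		}
		closeParent()
		parent = child
	}
	return nil
}

func main() {
	c := &node{name: "c"}
	b := &node{name: "b", children: map[string]*node{"c": c}}
	a := &node{name: "a", children: map[string]*node{"b": b}}
	_ = visit(a, []string{"b", "c"}) // closes b, then c via the defer; never a
}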
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index a4934a565..cfb33a398 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -1,7 +1,13 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
package(licenses = ["notice"])
+proto_library(
+ name = "control",
+ srcs = ["control.proto"],
+ visibility = ["//visibility:public"],
+)
+
go_library(
name = "control",
srcs = [
diff --git a/pkg/sentry/control/control.proto b/pkg/sentry/control/control.proto
new file mode 100644
index 000000000..72dda3fbc
--- /dev/null
+++ b/pkg/sentry/control/control.proto
@@ -0,0 +1,40 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+// ControlConfig configures the permission of controls.
+message ControlConfig {
+ // Names for individual control URPC service objects.
+ // Any new service object that should be given conditional access should be
+ // named here and conditionally added based on presence in allowed_controls.
+ enum Endpoint {
+ UNKNOWN = 0;
+ EVENTS = 1;
+ FS = 2;
+ LIFECYCLE = 3;
+ LOGGING = 4;
+ PROFILE = 5;
+ USAGE = 6;
+ PROC = 7;
+ STATE = 8;
+ DEBUG = 9;
+ }
+
+ // allowed_controls represents which endpoints may be registered to the
+ // server.
+ repeated Endpoint allowed_controls = 1;
+}
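A sketch of how a server might consume allowed_controls; the endpoint values mirror the enum above, but the registration callback is a placeholder rather than the actual runsc URPC wiring:

package main

import "fmt"

// endpoint mirrors ControlConfig.Endpoint above (remaining values elided).
type endpoint int32

const (
	endpointUnknown endpoint = iota
	endpointEvents
	endpointFS
	endpointLifecycle
)

// registerAllowed registers only the controllers named in allowed_controls.
func registerAllowed(register func(name string), allowed []endpoint) {
	for _, ep := range allowed {
		switch ep {
		case endpointEvents:
			register("events")
		case endpointFS:
			register("fs")
		case endpointLifecycle:
			register("lifecycle")
		}
	}
}

func main() {
	registerAllowed(
		func(name string) { fmt.Println("registered:", name) },
		[]endpoint{endpointEvents, endpointLifecycle},
	)
}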
diff --git a/pkg/sentry/devices/quotedev/BUILD b/pkg/sentry/devices/quotedev/BUILD
deleted file mode 100644
index ee946610a..000000000
--- a/pkg/sentry/devices/quotedev/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-licenses(["notice"])
-
-go_library(
- name = "quotedev",
- srcs = ["quotedev.go"],
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/abi/linux",
- "//pkg/context",
- "//pkg/errors/linuxerr",
- "//pkg/sentry/fsimpl/devtmpfs",
- "//pkg/sentry/vfs",
- ],
-)
diff --git a/pkg/sentry/devices/quotedev/quotedev.go b/pkg/sentry/devices/quotedev/quotedev.go
deleted file mode 100644
index 140856a4a..000000000
--- a/pkg/sentry/devices/quotedev/quotedev.go
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package quotedev implements a vfs.Device for /dev/gvisor_quote.
-package quotedev
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/errors/linuxerr"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-const (
- quoteDevMinor = 0
-)
-
-// quoteDevice implements vfs.Device for /dev/gvisor_quote
-//
-// +stateify savable
-type quoteDevice struct{}
-
-// Open implements vfs.Device.Open.
-// TODO(b/157161182): Add support for attestation ioctls.
-func (quoteDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- return nil, linuxerr.EIO
-}
-
-// Register registers all devices implemented by this package in vfsObj.
-func Register(vfsObj *vfs.VirtualFilesystem) error {
- return vfsObj.RegisterDevice(vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, quoteDevice{}, &vfs.RegisterDeviceOptions{
- GroupName: "gvisor_quote",
- })
-}
-
-// CreateDevtmpfsFiles creates device special files in dev representing all
-// devices implemented by this package.
-func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
- return dev.CreateDeviceFile(ctx, "gvisor_quote", vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, 0666 /* mode */)
-}
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index 5c57f6fea..02540a754 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sync"
)
// handle represents a remote "open file descriptor", consisting of an opened
@@ -130,3 +131,43 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
}
return uint64(n), cperr
}
+
+type handleReadWriter struct {
+ ctx context.Context
+ h *handle
+ off uint64
+}
+
+var handleReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &handleReadWriter{}
+ },
+}
+
+func getHandleReadWriter(ctx context.Context, h *handle, offset int64) *handleReadWriter {
+ rw := handleReadWriterPool.Get().(*handleReadWriter)
+ rw.ctx = ctx
+ rw.h = h
+ rw.off = uint64(offset)
+ return rw
+}
+
+func putHandleReadWriter(rw *handleReadWriter) {
+ rw.ctx = nil
+ rw.h = nil
+ handleReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *handleReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ n, err := rw.h.readToBlocksAt(rw.ctx, dsts, rw.off)
+ rw.off += n
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *handleReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ n, err := rw.h.writeFromBlocksAt(rw.ctx, srcs, rw.off)
+ rw.off += n
+ return n, err
+}
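The pool exists so that each PRead/pwrite on a special file (see special_file.go below) does not allocate a fresh adapter. The same shape in plain io terms, with offsetReader standing in for handleReadWriter, which adapts a gofer handle to safemem; note that references are dropped before returning to the pool, as putHandleReadWriter does above:

package main

import (
	"fmt"
	"io"
	"strings"
	"sync"
)

// offsetReader adapts an io.ReaderAt plus a cursor, advancing the cursor on
// each read, just as handleReadWriter advances off.
type offsetReader struct {
	r   io.ReaderAt
	off int64
}

func (or *offsetReader) Read(p []byte) (int, error) {
	n, err := or.r.ReadAt(p, or.off)
	or.off += int64(n)
	return n, err
}

var pool = sync.Pool{New: func() interface{} { return &offsetReader{} }}

func main() {
	or := pool.Get().(*offsetReader)
	or.r, or.off = strings.NewReader("pooled read/writers avoid per-I/O allocations"), 7
	buf := make([]byte, 12)
	n, _ := or.Read(buf)
	fmt.Printf("%q next offset %d\n", buf[:n], or.off)
	or.r = nil // drop references before returning to the pool
	pool.Put(or)
}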
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
index e67422a2f..8dcbc61ed 100644
--- a/pkg/sentry/fsimpl/gofer/save_restore.go
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -158,6 +158,10 @@ func (d *dentryPlatformFile) afterLoad() {
// afterLoad is invoked by stateify.
func (fd *specialFileFD) afterLoad() {
fd.handle.fd = -1
+ if fd.hostFileMapper.IsInited() {
+ // Ensure that we don't call fd.hostFileMapper.Init() again.
+ fd.hostFileMapperInitOnce.Do(func() {})
+ }
}
// CompleteRestore implements
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 144a1045e..a8d47b65b 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -22,10 +22,13 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -75,6 +78,16 @@ type specialFileFD struct {
bufMu sync.Mutex `state:"nosave"`
haveBuf uint32
buf []byte
+
+ // If handle.fd >= 0, hostFileMapper caches mappings of handle.fd, and
+ // hostFileMapperInitOnce is used to initialize it on first use.
+ hostFileMapperInitOnce sync.Once `state:"nosave"`
+ hostFileMapper fsutil.HostFileMapper
+
+ // If handle.fd >= 0, fileRefs counts references on memmap.File offsets.
+ // fileRefs is protected by fileRefsMu.
+ fileRefsMu sync.Mutex `state:"nosave"`
+ fileRefs fsutil.FrameRefSet
}
func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) {
@@ -229,23 +242,13 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
}
}
- // Going through dst.CopyOutFrom() would hold MM locks around file
- // operations of unknown duration. For regularFileFD, doing so is necessary
- // to support mmap due to lock ordering; MM locks precede dentry.dataMu.
- // That doesn't hold here since specialFileFD doesn't client-cache data.
- // Just buffer the read instead.
- buf := make([]byte, dst.NumBytes())
- n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ rw := getHandleReadWriter(ctx, &fd.handle, offset)
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putHandleReadWriter(rw)
if linuxerr.Equals(linuxerr.EAGAIN, err) {
err = linuxerr.ErrWouldBlock
}
- if n == 0 {
- return bufN, err
- }
- if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil {
- return bufN + int64(cp), cperr
- }
- return bufN + int64(n), err
+ return bufN + n, err
}
// Read implements vfs.FileDescriptionImpl.Read.
@@ -316,20 +319,15 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
}
}
- // Do a buffered write. See rationale in PRead.
- buf := make([]byte, src.NumBytes())
- copied, copyErr := src.CopyIn(ctx, buf)
- if copied == 0 && copyErr != nil {
- // Only return the error if we didn't get any data.
- return 0, offset, copyErr
- }
- n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset))
+ rw := getHandleReadWriter(ctx, &fd.handle, offset)
+ n, err := src.CopyInTo(ctx, rw)
+ putHandleReadWriter(rw)
if linuxerr.Equals(linuxerr.EAGAIN, err) {
err = linuxerr.ErrWouldBlock
}
// Update offset if the offset is valid.
if offset >= 0 {
- offset += int64(n)
+ offset += n
}
// Update file size for regular files.
if fd.isRegularFile {
@@ -340,10 +338,7 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
atomic.StoreUint64(&d.size, uint64(offset))
}
}
- if err != nil {
- return int64(n), offset, err
- }
- return int64(n), offset, copyErr
+ return int64(n), offset, err
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -411,3 +406,85 @@ func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error
}
return nil
}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *specialFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ if fd.handle.fd < 0 || fd.filesystem().opts.forcePageCache {
+ return linuxerr.ENODEV
+ }
+ // After this point, fd may be used as a memmap.Mappable and memmap.File.
+ fd.hostFileMapperInitOnce.Do(fd.hostFileMapper.Init)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts)
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (fd *specialFileFD) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
+ fd.hostFileMapper.IncRefOn(memmap.MappableRange{offset, offset + uint64(ar.Length())})
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (fd *specialFileFD) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
+ fd.hostFileMapper.DecRefOn(memmap.MappableRange{offset, offset + uint64(ar.Length())})
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (fd *specialFileFD) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
+ return fd.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (fd *specialFileFD) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
+ mr := optional
+ if fd.filesystem().opts.limitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: fd,
+ Offset: mr.Start,
+ Perms: hostarch.AnyAccess,
+ },
+ }, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (fd *specialFileFD) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
+
+// IncRef implements memmap.File.IncRef.
+func (fd *specialFileFD) IncRef(fr memmap.FileRange) {
+ fd.fileRefsMu.Lock()
+ defer fd.fileRefsMu.Unlock()
+ fd.fileRefs.IncRefAndAccount(fr)
+}
+
+// DecRef implements memmap.File.DecRef.
+func (fd *specialFileFD) DecRef(fr memmap.FileRange) {
+ fd.fileRefsMu.Lock()
+ defer fd.fileRefsMu.Unlock()
+ fd.fileRefs.DecRefAndAccount(fr)
+}
+
+// MapInternal implements memmap.File.MapInternal.
+func (fd *specialFileFD) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) {
+ fd.requireHostFD()
+ return fd.hostFileMapper.MapInternal(fr, int(fd.handle.fd), at.Write)
+}
+
+// FD implements memmap.File.FD.
+func (fd *specialFileFD) FD() int {
+ fd.requireHostFD()
+ return int(fd.handle.fd)
+}
+
+func (fd *specialFileFD) requireHostFD() {
+ if fd.handle.fd < 0 {
+ // This is possible if fd was successfully mmapped before saving, then
+ // was restored without a host FD. This is unrecoverable: without a
+ // host FD, we can't mmap this file post-restore.
+ panic("gofer.specialFileFD can no longer be memory-mapped without a host FD")
+ }
+}
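With a donated host FD, specialFileFD mappings now translate directly to the host descriptor rather than a sentry-side page cache, which is why ConfigureMMap refuses when handle.fd < 0. The host-side effect is an ordinary shared file mapping; a standalone, Linux-only illustration using a temp file in place of a gofer-provided FD:

package main

import (
	"fmt"
	"os"
	"syscall"
)

func main() {
	const msg = "backed by a host fd"
	f, err := os.CreateTemp("", "specialfile")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	if _, err := f.WriteString(msg); err != nil {
		panic(err)
	}
	// A MAP_SHARED mapping of the host FD: reads come straight from the
	// host file with no intermediate buffer, mirroring what the sentry
	// now hands out for special files.
	data, err := syscall.Mmap(int(f.Fd()), 0, len(msg), syscall.PROT_READ, syscall.MAP_SHARED)
	if err != nil {
		panic(err)
	}
	defer syscall.Munmap(data)
	fmt.Println(string(data))
}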
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 5955948f0..c12abdf33 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -1,10 +1,24 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
+go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "verity",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dentry",
+ "Linker": "*dentry",
+ },
+)
+
go_library(
name = "verity",
srcs = [
+ "dentry_list.go",
"filesystem.go",
"save_restore.go",
"verity.go",
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index e147d6b07..52d47994d 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -66,40 +66,23 @@ func putDentrySlice(ds *[]*dentry) {
dentrySlicePool.Put(ds)
}
-// renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls
-// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for
+// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls
+// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for
// writing.
//
// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
// but dentry slices are allocated lazily, and it's much easier to say "defer
-// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
-// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
+// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
// +checklocksrelease:fs.renameMu
-func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
+func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
fs.renameMu.RUnlock()
if *ds == nil {
return
}
- if len(**ds) != 0 {
- fs.renameMu.Lock()
- for _, d := range **ds {
- d.checkDropLocked(ctx)
- }
- fs.renameMu.Unlock()
- }
- putDentrySlice(*ds)
-}
-
-// +checklocksrelease:fs.renameMu
-func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
- if *ds == nil {
- fs.renameMu.Unlock()
- return
- }
for _, d := range **ds {
- d.checkDropLocked(ctx)
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
- fs.renameMu.Unlock()
putDentrySlice(*ds)
}
@@ -700,7 +683,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
}
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return err
@@ -712,7 +695,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -733,7 +716,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
@@ -770,7 +753,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
if rp.Done() {
@@ -952,7 +935,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
@@ -982,7 +965,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return linux.Statx{}, err
@@ -1028,7 +1011,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
if _, err := fs.resolveLocked(ctx, rp, &ds); err != nil {
return nil, err
}
@@ -1039,7 +1022,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -1055,7 +1038,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 23841ecf7..d2526263c 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -23,10 +23,12 @@
// Lock order:
//
// filesystem.renameMu
-// dentry.dirMu
-// fileDescription.mu
-// filesystem.verityMu
-// dentry.hashMu
+// dentry.cachingMu
+// filesystem.cacheMu
+// dentry.dirMu
+// fileDescription.mu
+// filesystem.verityMu
+// dentry.hashMu
//
// Locking dentry.dirMu in multiple dentries requires that parent dentries are
// locked before child dentries, and that filesystem.renameMu is locked to
@@ -96,6 +98,9 @@ const (
// sizeOfStringInt32 is the size for a 32 bit integer stored as string in
// extended attributes. The maximum value of a 32 bit integer has 10 digits.
sizeOfStringInt32 = 10
+
+ // defaultMaxCachedDentries is the default limit of dentry cache.
+ defaultMaxCachedDentries = uint64(1000)
)
var (
@@ -106,9 +111,10 @@ var (
// Mount option names for verityfs.
const (
- moptLowerPath = "lower_path"
- moptRootHash = "root_hash"
- moptRootName = "root_name"
+ moptLowerPath = "lower_path"
+ moptRootHash = "root_hash"
+ moptRootName = "root_name"
+ moptDentryCacheLimit = "dentry_cache_limit"
)
// HashAlgorithm is a type specifying the algorithm used to hash the file
@@ -188,6 +194,17 @@ type filesystem struct {
// dentries.
renameMu sync.RWMutex `state:"nosave"`
+ // cachedDentries contains all dentries with 0 references. (Due to race
+ // conditions, it may also contain dentries with non-zero references.)
+ // cachedDentriesLen is the number of dentries in cachedDentries. These
+ // fields are protected by cacheMu.
+ cacheMu sync.Mutex `state:"nosave"`
+ cachedDentries dentryList
+ cachedDentriesLen uint64
+
+ // maxCachedDentries is the maximum size of filesystem.cachedDentries.
+ maxCachedDentries uint64
+
// verityMu synchronizes enabling verity files, protects files or
// directories from being enabled by different threads simultaneously.
// It also ensures that verity does not access files that are being
@@ -198,6 +215,10 @@ type filesystem struct {
// is for the whole file system to ensure that no more than one file is
// enabled the same time.
verityMu sync.RWMutex `state:"nosave"`
+
+ // released is nonzero once filesystem.Release has been called. It is accessed
+ // with atomic memory operations.
+ released int32
}
// InternalFilesystemOptions may be passed as
@@ -266,6 +287,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
delete(mopts, moptRootName)
rootName = root
}
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts[moptDentryCacheLimit]; ok {
+ delete(mopts, moptDentryCacheLimit)
+ maxCD, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str)
+ return nil, nil, linuxerr.EINVAL
+ }
+ maxCachedDentries = maxCD
+ }
// Check for unparsed options.
if len(mopts) != 0 {
@@ -339,12 +370,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
action: iopts.Action,
opts: opts.Data,
allowRuntimeEnable: iopts.AllowRuntimeEnable,
+ maxCachedDentries: maxCachedDentries,
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
// Construct the root dentry.
d := fs.newDentry()
- d.refs = 1
+ // Set the root's reference count to 2. One reference is returned to
+ // the caller, and the other is held by fs to prevent the root from
+ // being "cached" and subsequently evicted.
+ d.refs = 2
lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root())
lowerVD.IncRef()
d.lowerVD = lowerVD
@@ -519,7 +554,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release(ctx context.Context) {
+ atomic.StoreInt32(&fs.released, 1)
fs.lowerMount.DecRef(ctx)
+
+ fs.renameMu.Lock()
+ fs.evictAllCachedDentriesLocked(ctx)
+ fs.renameMu.Unlock()
+
+ // An extra reference was held by the filesystem on the root to prevent
+ // it from being cached/evicted.
+ fs.rootDentry.DecRef(ctx)
}
// MountOptions implements vfs.FilesystemImpl.MountOptions.
@@ -533,6 +577,11 @@ func (fs *filesystem) MountOptions() string {
type dentry struct {
vfsd vfs.Dentry
+ // refs is the reference count. Each dentry holds a reference on its
+ // parent, even if disowned. When refs reaches 0, the dentry may be
+ // added to the cache or destroyed. If refs == -1, the dentry has
+ // already been destroyed. refs is accessed using atomic memory
+ // operations.
refs int64
// fs is the owning filesystem. fs is immutable.
@@ -587,13 +636,23 @@ type dentry struct {
// is protected by hashMu.
hashMu sync.RWMutex `state:"nosave"`
hash []byte
+
+ // cachingMu is used to synchronize concurrent dentry caching attempts on
+ // this dentry.
+ cachingMu sync.Mutex `state:"nosave"`
+
+ // If cached is true, dentryEntry links dentry into
+ // filesystem.cachedDentries. cached and dentryEntry are protected by
+ // cachingMu.
+ cached bool
+ dentryEntry
}
// newDentry creates a new dentry representing the given verity file. The
-// dentry initially has no references; it is the caller's responsibility to set
-// the dentry's reference count and/or call dentry.destroy() as appropriate.
-// The dentry is initially invalid in that it contains no underlying dentry;
-// the caller is responsible for setting them.
+// dentry initially has no references, but is not cached; it is the caller's
+// responsibility to set the dentry's reference count and/or call
+// dentry.destroy() as appropriate. The dentry is initially invalid in that it
+// contains no underlying dentry; the caller is responsible for setting them.
func (fs *filesystem) newDentry() *dentry {
d := &dentry{
fs: fs,
@@ -629,42 +688,23 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
- r := atomic.AddInt64(&d.refs, -1)
- if d.LogRefs() {
- refsvfs2.LogDecRef(d, r)
- }
- if r == 0 {
- d.fs.renameMu.Lock()
- d.checkDropLocked(ctx)
- d.fs.renameMu.Unlock()
- } else if r < 0 {
- panic("verity.dentry.DecRef() called without holding a reference")
+ if d.decRefNoCaching() == 0 {
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
}
-func (d *dentry) decRefLocked(ctx context.Context) {
+// decRefNoCaching decrements d's reference count without calling
+// d.checkCachingLocked, even if d's reference count reaches 0; callers are
+// responsible for ensuring that d.checkCachingLocked will be called later.
+func (d *dentry) decRefNoCaching() int64 {
r := atomic.AddInt64(&d.refs, -1)
if d.LogRefs() {
refsvfs2.LogDecRef(d, r)
}
- if r == 0 {
- d.checkDropLocked(ctx)
- } else if r < 0 {
- panic("verity.dentry.decRefLocked() called without holding a reference")
+ if r < 0 {
+ panic("verity.dentry.decRefNoCaching() called without holding a reference")
}
-}
-
-// checkDropLocked should be called after d's reference count becomes 0 or it
-// becomes deleted.
-func (d *dentry) checkDropLocked(ctx context.Context) {
- // Dentries with a positive reference count must be retained. Dentries
- // with a negative reference count have already been destroyed.
- if atomic.LoadInt64(&d.refs) != 0 {
- return
- }
- // Refs is still zero; destroy it.
- d.destroyLocked(ctx)
- return
+ return r
}
// destroyLocked destroys the dentry.
@@ -683,6 +723,12 @@ func (d *dentry) destroyLocked(ctx context.Context) {
panic("verity.dentry.destroyLocked() called with references on the dentry")
}
+ // Drop the reference held by d on its parent without recursively
+ // locking d.fs.renameMu.
+ if d.parent != nil && d.parent.decRefNoCaching() == 0 {
+ d.parent.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
+ }
+
if d.lowerVD.Ok() {
d.lowerVD.DecRef(ctx)
}
@@ -695,7 +741,6 @@ func (d *dentry) destroyLocked(ctx context.Context) {
delete(d.parent.children, d.name)
}
d.parent.dirMu.Unlock()
- d.parent.decRefLocked(ctx)
}
refsvfs2.Unregister(d)
}
@@ -734,6 +779,140 @@ func (d *dentry) OnZeroWatches(context.Context) {
//TODO(b/159261227): Implement OnZeroWatches.
}
+// checkCachingLocked should be called after d's reference count becomes 0 or
+// it becomes disowned.
+//
+// For performance, checkCachingLocked can also be called after d's reference
+// count becomes non-zero, so that d can be removed from the LRU cache. This
+// may help in reducing the size of the cache and hence reduce evictions. Note
+// that this is not necessary for correctness.
+//
+// It may be called on a destroyed dentry. For example,
+// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times
+// for the same dentry when the dentry is visited more than once in the same
+// operation. One of the calls may destroy the dentry, so subsequent calls will
+// do nothing.
+//
+// Preconditions: d.fs.renameMu must be locked for writing if
+// renameMuWriteLocked is true; it may be temporarily unlocked.
+func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked bool) {
+ d.cachingMu.Lock()
+ refs := atomic.LoadInt64(&d.refs)
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ d.cachingMu.Unlock()
+ return
+ }
+ if refs > 0 {
+ // fs.cachedDentries is permitted to contain dentries with non-zero refs,
+ // which are skipped by fs.evictCachedDentryLocked() upon reaching the end
+ // of the LRU. But it is still beneficial to remove d from the cache as we
+ // are already holding d.cachingMu. Keeping a cleaner cache also reduces
+ // the number of evictions (which is expensive as it acquires fs.renameMu).
+ d.removeFromCacheLocked()
+ d.cachingMu.Unlock()
+ return
+ }
+
+ if atomic.LoadInt32(&d.fs.released) != 0 {
+ d.cachingMu.Unlock()
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu to access d.parent. Lock it for writing as
+ // needed by d.destroyLocked() later.
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ }
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ delete(d.parent.children, d.name)
+ d.parent.dirMu.Unlock()
+ }
+ d.destroyLocked(ctx) // +checklocksforce: see above.
+ return
+ }
+
+ d.fs.cacheMu.Lock()
+ // If d is already cached, just move it to the front of the LRU.
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cacheMu.Unlock()
+ d.cachingMu.Unlock()
+ return
+ }
+ // Cache the dentry, then evict the least recently used cached dentry if
+ // the cache becomes over-full.
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cachedDentriesLen++
+ d.cached = true
+ shouldEvict := d.fs.cachedDentriesLen > d.fs.maxCachedDentries
+ d.fs.cacheMu.Unlock()
+ d.cachingMu.Unlock()
+
+ if shouldEvict {
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu for writing as needed by
+ // d.evictCachedDentryLocked().
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ }
+ d.fs.evictCachedDentryLocked(ctx) // +checklocksforce: see above.
+ }
+}
+
+// Preconditions: d.cachingMu must be locked.
+func (d *dentry) removeFromCacheLocked() {
+ if d.cached {
+ d.fs.cacheMu.Lock()
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.fs.cacheMu.Unlock()
+ d.cached = false
+ }
+}
+
+// Precondition: fs.renameMu must be locked for writing; it may be temporarily
+// unlocked.
+// +checklocks:fs.renameMu
+func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
+ for fs.cachedDentriesLen != 0 {
+ fs.evictCachedDentryLocked(ctx)
+ }
+}
+
+// Preconditions:
+// * fs.renameMu must be locked for writing; it may be temporarily unlocked.
+// +checklocks:fs.renameMu
+func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
+ fs.cacheMu.Lock()
+ victim := fs.cachedDentries.Back()
+ fs.cacheMu.Unlock()
+ if victim == nil {
+ // fs.cachedDentries may have become empty between when it was
+ // checked and when we locked fs.cacheMu.
+ return
+ }
+
+ victim.cachingMu.Lock()
+ victim.removeFromCacheLocked()
+ // victim.refs may have become non-zero from an earlier path resolution
+ // since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) != 0 {
+ victim.cachingMu.Unlock()
+ return
+ }
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ victim.parent.dirMu.Unlock()
+ }
+ victim.cachingMu.Unlock()
+ victim.destroyLocked(ctx) // +checklocksforce: owned as precondition, victim.fs == fs.
+}
+
func (d *dentry) isSymlink() bool {
return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK
}
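The cache machinery above boils down to one invariant: a dentry whose reference count hits zero goes to the front of the LRU, overflow evicts from the back, and eviction must re-check refs because a racing lookup may have revived the victim between the two locks. A compact sketch of just that invariant, using container/list in place of the generated dentryList and ignoring all locking:

package main

import (
	"container/list"
	"fmt"
)

type dentry struct {
	name string
	refs int64
	elem *list.Element
}

type cache struct {
	lru *list.List // front = most recently unreferenced
	max int
}

// noteZeroRefs mirrors checkCachingLocked's cache path: move an already
// cached dentry to the front, otherwise insert it and evict from the back
// when over the limit.
func (c *cache) noteZeroRefs(d *dentry) (evicted *dentry) {
	if d.elem != nil {
		c.lru.MoveToFront(d.elem)
		return nil
	}
	d.elem = c.lru.PushFront(d)
	if c.lru.Len() <= c.max {
		return nil
	}
	back := c.lru.Back()
	victim := back.Value.(*dentry)
	c.lru.Remove(back)
	victim.elem = nil
	if victim.refs != 0 {
		return nil // revived by a racing lookup; drop from cache, don't destroy
	}
	return victim // caller destroys it, under renameMu in the real code
}

func main() {
	c := &cache{lru: list.New(), max: 2}
	a, b, d := &dentry{name: "a"}, &dentry{name: "b"}, &dentry{name: "d"}
	c.noteZeroRefs(a)
	c.noteZeroRefs(b)
	if v := c.noteZeroRefs(d); v != nil {
		fmt.Println("evicted:", v.name) // evicted: a
	}
}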
diff --git a/pkg/sentry/kernel/ipc/object.go b/pkg/sentry/kernel/ipc/object.go
index 387b35e7e..facd157c7 100644
--- a/pkg/sentry/kernel/ipc/object.go
+++ b/pkg/sentry/kernel/ipc/object.go
@@ -19,6 +19,8 @@ package ipc
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -113,3 +115,36 @@ func (o *Object) CheckPermissions(creds *auth.Credentials, req fs.PermMask) bool
}
return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, o.UserNS)
}
+
+// Set modifies attributes for an IPC object. See *ctl(IPC_SET).
+//
+// Precondition: Mechanism.mu must be held.
+func (o *Object) Set(ctx context.Context, perm *linux.IPCPerm) error {
+ creds := auth.CredentialsFromContext(ctx)
+ uid := creds.UserNamespace.MapToKUID(auth.UID(perm.UID))
+ gid := creds.UserNamespace.MapToKGID(auth.GID(perm.GID))
+ if !uid.Ok() || !gid.Ok() {
+ // The man pages don't specify an errno for invalid uid/gid, but EINVAL
+ // is generally used for invalid arguments.
+ return linuxerr.EINVAL
+ }
+
+ if !o.CheckOwnership(creds) {
+ // "The argument cmd has the value IPC_SET or IPC_RMID, but the
+ // effective user ID of the calling process is not the creator (as
+ // found in msg_perm.cuid) or the owner (as found in msg_perm.uid)
+ // of the message queue, and the caller is not privileged (Linux:
+ // does not have the CAP_SYS_ADMIN capability)."
+ return linuxerr.EPERM
+ }
+
+ // User may only modify the lower 9 bits of the mode. All the other bits are
+ // always 0 for the underlying inode.
+ mode := linux.FileMode(perm.Mode & 0x1ff)
+
+ o.Perms = fs.FilePermsFromMode(mode)
+ o.Owner.UID = uid
+ o.Owner.GID = gid
+
+ return nil
+}
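The 0x1ff mask (octal 0777) means IPC_SET silently drops the setuid/setgid/sticky bits and keeps only the permission bits, for example:

package main

import "fmt"

func main() {
	var mode uint16 = 0o4777 // setuid | rwxrwxrwx, as a caller might pass
	fmt.Printf("stored mode: %#o\n", mode&0x1ff) // stored mode: 0777
}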
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go
index fab396d7c..7c459d076 100644
--- a/pkg/sentry/kernel/msgqueue/msgqueue.go
+++ b/pkg/sentry/kernel/msgqueue/msgqueue.go
@@ -206,6 +206,48 @@ func (r *Registry) FindByID(id ipc.ID) (*Queue, error) {
return mech.(*Queue), nil
}
+// IPCInfo reports global parameters for message queues. See msgctl(IPC_INFO).
+func (r *Registry) IPCInfo(ctx context.Context) *linux.MsgInfo {
+ return &linux.MsgInfo{
+ MsgPool: linux.MSGPOOL,
+ MsgMap: linux.MSGMAP,
+ MsgMax: linux.MSGMAX,
+ MsgMnb: linux.MSGMNB,
+ MsgMni: linux.MSGMNI,
+ MsgSsz: linux.MSGSSZ,
+ MsgTql: linux.MSGTQL,
+ MsgSeg: linux.MSGSEG,
+ }
+}
+
+// MsgInfo reports global parameters for message queues. See msgctl(MSG_INFO).
+func (r *Registry) MsgInfo(ctx context.Context) *linux.MsgInfo {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ var messages, bytes uint64
+ r.reg.ForAllObjects(
+ func(o ipc.Mechanism) {
+ q := o.(*Queue)
+ q.mu.Lock()
+ messages += q.messageCount
+ bytes += q.byteCount
+ q.mu.Unlock()
+ },
+ )
+
+ return &linux.MsgInfo{
+ MsgPool: int32(r.reg.ObjectCount()),
+ MsgMap: int32(messages),
+ MsgTql: int32(bytes),
+ MsgMax: linux.MSGMAX,
+ MsgMnb: linux.MSGMNB,
+ MsgMni: linux.MSGMNI,
+ MsgSsz: linux.MSGSSZ,
+ MsgSeg: linux.MSGSEG,
+ }
+}
+
// Send appends a message to the message queue, and returns an error if sending
// fails. See msgsnd(2).
func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) error {
@@ -465,6 +507,73 @@ func (q *Queue) msgAtIndex(mType int64) *Message {
return msg
}
+// Set modifies some values of the queue. See msgctl(IPC_SET).
+func (q *Queue) Set(ctx context.Context, ds *linux.MsqidDS) error {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ creds := auth.CredentialsFromContext(ctx)
+ if ds.MsgQbytes > maxQueueBytes && !creds.HasCapabilityIn(linux.CAP_SYS_RESOURCE, q.obj.UserNS) {
+ // "An attempt (IPC_SET) was made to increase msg_qbytes beyond the
+ // system parameter MSGMNB, but the caller is not privileged (Linux:
+ // does not have the CAP_SYS_RESOURCE capability)."
+ return linuxerr.EPERM
+ }
+
+ if err := q.obj.Set(ctx, &ds.MsgPerm); err != nil {
+ return err
+ }
+
+ q.maxBytes = ds.MsgQbytes
+ q.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
+// Stat returns a MsqidDS object filled with information about the queue. See
+// msgctl(IPC_STAT) and msgctl(MSG_STAT).
+func (q *Queue) Stat(ctx context.Context) (*linux.MsqidDS, error) {
+ return q.stat(ctx, fs.PermMask{Read: true})
+}
+
+// StatAny is similar to Queue.Stat, but doesn't require read permission. See
+// msgctl(MSG_STAT_ANY).
+func (q *Queue) StatAny(ctx context.Context) (*linux.MsqidDS, error) {
+ return q.stat(ctx, fs.PermMask{})
+}
+
+// stat returns a MsqidDS object filled with information about the queue. An
+// error is returned if the user doesn't have the specified permissions.
+func (q *Queue) stat(ctx context.Context, mask fs.PermMask) (*linux.MsqidDS, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ creds := auth.CredentialsFromContext(ctx)
+ if !q.obj.CheckPermissions(creds, mask) {
+ // "The caller must have read permission on the message queue."
+ return nil, linuxerr.EACCES
+ }
+
+ return &linux.MsqidDS{
+ MsgPerm: linux.IPCPerm{
+ Key: uint32(q.obj.Key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(q.obj.Owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(q.obj.Owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(q.obj.Creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(q.obj.Creator.GID)),
+ Mode: uint16(q.obj.Perms.LinuxMode()),
+ Seq: 0, // IPC sequences not supported.
+ },
+ MsgStime: q.sendTime.TimeT(),
+ MsgRtime: q.receiveTime.TimeT(),
+ MsgCtime: q.changeTime.TimeT(),
+ MsgCbytes: q.byteCount,
+ MsgQnum: q.messageCount,
+ MsgQbytes: q.maxBytes,
+ MsgLspid: q.sendPID,
+ MsgLrpid: q.receivePID,
+ }, nil
+}
+
// Lock implements ipc.Mechanism.Lock.
func (q *Queue) Lock() {
q.mu.Lock()
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 8a5c81a68..28e466948 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -336,19 +336,15 @@ func (s *Set) Size() int {
return len(s.sems)
}
-// Change changes some fields from the set atomically.
-func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.FileOwner, perms fs.FilePermissions) error {
+// Set modifies attributes for a semaphore set. See semctl(IPC_SET).
+func (s *Set) Set(ctx context.Context, ds *linux.SemidDS) error {
s.mu.Lock()
defer s.mu.Unlock()
- // "The effective UID of the calling process must match the owner or creator
- // of the semaphore set, or the caller must be privileged."
- if !s.obj.CheckOwnership(creds) {
- return linuxerr.EACCES
+ if err := s.obj.Set(ctx, &ds.SemPerm); err != nil {
+ return err
}
- s.obj.Owner = owner
- s.obj.Perms = perms
s.changeTime = ktime.NowFromContext(ctx)
return nil
}
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index b8da0c76c..ab938fa3c 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -618,25 +618,10 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
s.mu.Lock()
defer s.mu.Unlock()
- creds := auth.CredentialsFromContext(ctx)
- if !s.obj.CheckOwnership(creds) {
- return linuxerr.EPERM
- }
-
- uid := creds.UserNamespace.MapToKUID(auth.UID(ds.ShmPerm.UID))
- gid := creds.UserNamespace.MapToKGID(auth.GID(ds.ShmPerm.GID))
- if !uid.Ok() || !gid.Ok() {
- return linuxerr.EINVAL
+ if err := s.obj.Set(ctx, &ds.ShmPerm); err != nil {
+ return err
}
- // User may only modify the lower 9 bits of the mode. All the other bits are
- // always 0 for the underlying inode.
- mode := linux.FileMode(ds.ShmPerm.Mode & 0x1ff)
- s.obj.Perms = fs.FilePermsFromMode(mode)
-
- s.obj.Owner.UID = uid
- s.obj.Owner.GID = gid
-
s.changeTime = ktime.NowFromContext(ctx)
return nil
}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 2eda15303..5814a4eca 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -489,11 +489,6 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID)
tg.signalHandlers.mu.Lock()
defer tg.signalHandlers.mu.Unlock()
- // TODO(gvisor.dev/issue/6148): "If tcsetpgrp() is called by a member of a
- // background process group in its session, and the calling process is not
- // blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all members of
- // this background process group."
-
// tty must be the controlling terminal.
if tg.tty != tty {
return -1, linuxerr.ENOTTY
@@ -516,6 +511,16 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID)
return -1, linuxerr.EPERM
}
+ signalAction := tg.signalHandlers.actions[linux.SIGTTOU]
+ // If the calling process is a member of a background group, a SIGTTOU
+ // signal is sent to all members of this background process group.
+ // We also need to check whether it is ignoring or blocking SIGTTOU.
+ ignored := signalAction.Handler == linux.SIG_IGN
+ blocked := tg.leader.signalMask&linux.SignalSetOf(linux.SIGTTOU) != 0
+ if tg.processGroup.id != tg.processGroup.session.foreground.id && !ignored && !blocked {
+ tg.leader.sendSignalLocked(SignalInfoPriv(linux.SIGTTOU), true)
+ }
+
tg.processGroup.session.foreground.id = pgid
return 0, nil
}
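The added condition is the tcsetpgrp(3) background rule: signal the caller with SIGTTOU unless it is in the foreground process group or is ignoring or blocking the signal. Boiled down, with plain booleans standing in for the kernel types:

package main

import "fmt"

// caller captures just the state SetForegroundProcessGroup consults.
type caller struct {
	inForeground bool
	ignoresTTOU  bool
	blocksTTOU   bool
}

// needsSIGTTOU mirrors the condition added above: signal a background
// caller unless it ignores or blocks SIGTTOU.
func needsSIGTTOU(c caller) bool {
	return !c.inForeground && !c.ignoresTTOU && !c.blocksTTOU
}

func main() {
	fmt.Println(needsSIGTTOU(caller{}))                  // true: background, not ignoring or blocking
	fmt.Println(needsSIGTTOU(caller{ignoresTTOU: true})) // false
}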
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index b7f765cd7..d71d64580 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -77,15 +77,6 @@ func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64)
return nil
}
- // Only unmap after it is assured that the address is a valid AIO context
- // to prevent random memory from being unmapped.
- //
- // Note: It's possible to unmap this address and map something else into
- // the same address. Then it would be unmapping memory that it doesn't own.
- // This is, however, the way Linux implements AIO. Keeps the same [weird]
- // semantics in case anyone relies on it.
- mm.MUnmap(ctx, hostarch.Addr(id), aioRingBufferSize)
-
delete(mm.aioManager.contexts, id)
aioCtx.destroy()
return aioCtx
@@ -411,6 +402,15 @@ func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOC
return nil
}
+ // Only unmap after it is assured that the address is a valid AIO context
+ // to prevent random memory from being unmapped.
+ //
+ // Note: It's possible to unmap this address and map something else into
+ // the same address. Then it would be unmapping memory that it doesn't own.
+ // This is, however, the way Linux implements AIO. Keeps the same [weird]
+ // semantics in case anyone relies on it.
+ mm.MUnmap(ctx, hostarch.Addr(id), aioRingBufferSize)
+
mm.aioManager.mu.Lock()
defer mm.aioManager.mu.Unlock()
return mm.destroyAIOContextLocked(ctx, id)
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
index 8fd8287b3..7a3c97c5a 100644
--- a/pkg/sentry/platform/kvm/bluepill_fault.go
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -55,11 +55,7 @@ func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virt
}
// Adjust the block to match our size.
- physicalStart = alignedPhysical & faultBlockMask
- if physicalStart < pr.physical {
- // Bound the starting point to the start of the region.
- physicalStart = pr.physical
- }
+ physicalStart = pr.physical + (alignedPhysical-pr.physical)&faultBlockMask
virtualStart = pr.virtual + (physicalStart - pr.physical)
physicalEnd := physicalStart + faultBlockSize
if physicalEnd > end {
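The one-line change switches the alignment base: the old code aligned the fault address to an absolute faultBlockSize boundary and then clamped to the region start, while the new code rounds the offset within the region down to a block boundary, so fault blocks tile the region from its own start and can never precede it. A worked example with illustrative numbers (2 MiB blocks; the platform's actual faultBlockSize may differ):

package main

import "fmt"

func main() {
	const faultBlockSize = uintptr(2 << 20) // assumed block size
	const faultBlockMask = ^(faultBlockSize - 1)

	prPhysical := uintptr(0x1_0010_0000)      // region start, not block-aligned
	alignedPhysical := uintptr(0x1_0050_0000) // faulting address (page-aligned)

	// Old: absolute alignment, then clamp to the region start.
	old := alignedPhysical & faultBlockMask
	if old < prPhysical {
		old = prPhysical
	}
	// New: align the offset relative to the region start.
	fixed := prPhysical + (alignedPhysical-prPhysical)&faultBlockMask

	fmt.Printf("old=%#x fixed=%#x\n", old, fixed) // old=0x100400000 fixed=0x100500000
}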
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 56f90d952..2046a48b9 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -123,7 +123,7 @@ var AMD64 = &kernel.SyscallTable{
68: syscalls.Supported("msgget", Msgget),
69: syscalls.Supported("msgsnd", Msgsnd),
70: syscalls.Supported("msgrcv", Msgrcv),
- 71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
+ 71: syscalls.Supported("msgctl", Msgctl),
72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
@@ -616,7 +616,7 @@ var ARM64 = &kernel.SyscallTable{
184: syscalls.ErrorWithEvent("mq_notify", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
185: syscalls.ErrorWithEvent("mq_getsetattr", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
186: syscalls.Supported("msgget", Msgget),
- 187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
+ 187: syscalls.Supported("msgctl", Msgctl),
188: syscalls.Supported("msgrcv", Msgrcv),
189: syscalls.Supported("msgsnd", Msgsnd),
190: syscalls.Supported("semget", Semget),
diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go
index 5259ade90..60b989ee7 100644
--- a/pkg/sentry/syscalls/linux/sys_msgqueue.go
+++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go
@@ -130,12 +130,63 @@ func receive(t *kernel.Task, id ipc.ID, mType int64, maxSize int64, msgCopy, wai
func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := ipc.ID(args[0].Int())
cmd := args[1].Int()
+ buf := args[2].Pointer()
creds := auth.CredentialsFromContext(t)
+ r := t.IPCNamespace().MsgqueueRegistry()
+
switch cmd {
+ case linux.IPC_INFO:
+ info := r.IPCInfo(t)
+ _, err := info.CopyOut(t, buf)
+ return 0, nil, err
+ case linux.MSG_INFO:
+ msgInfo := r.MsgInfo(t)
+ _, err := msgInfo.CopyOut(t, buf)
+ return 0, nil, err
case linux.IPC_RMID:
- return 0, nil, t.IPCNamespace().MsgqueueRegistry().Remove(id, creds)
+ return 0, nil, r.Remove(id, creds)
+ }
+
+ // Remaining commands use a queue.
+ queue, err := r.FindByID(id)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch cmd {
+ case linux.MSG_STAT:
+ // Technically, we should be treating id as "an index into the kernel's
+ // internal array that maintains information about all message queues on
+ // the system". Since we don't track queues in an array, we'll just
+ // pretend the msqid is the index and do the same thing as IPC_STAT.
+ // Linux also uses the index as the msqid.
+ fallthrough
+ case linux.IPC_STAT:
+ stat, err := queue.Stat(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = stat.CopyOut(t, buf)
+ return 0, nil, err
+
+ case linux.MSG_STAT_ANY:
+ stat, err := queue.StatAny(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = stat.CopyOut(t, buf)
+ return 0, nil, err
+
+ case linux.IPC_SET:
+ var ds linux.MsqidDS
+ if _, err := ds.CopyIn(t, buf); err != nil {
+ return 0, nil, linuxerr.EINVAL
+ }
+ err := queue.Set(t, &ds)
+ return 0, nil, err
+
default:
return 0, nil, linuxerr.EINVAL
}
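Editor's note: the rewritten Msgctl dispatches in two tiers: IPC_INFO, MSG_INFO, and IPC_RMID act on the registry, while the STAT variants and IPC_SET first resolve a queue by ID. A minimal sketch of that dispatch shape, with hypothetical registry/queue types standing in for the real ones:

func msgctl(r *registry, id, cmd int) error {
	switch cmd {
	case ipcInfo, msgInfo, ipcRmid:
		// Registry-level commands need no queue lookup.
		return r.registryCmd(id, cmd)
	}
	q, err := r.findByID(id) // every remaining command targets one queue
	if err != nil {
		return err
	}
	return q.queueCmd(cmd) // MSG_STAT, IPC_STAT, MSG_STAT_ANY, IPC_SET
}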
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index f61cc466c..5a119b21c 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
@@ -166,8 +165,7 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
- perms := fs.FilePermsFromMode(linux.FileMode(s.SemPerm.Mode & 0777))
- return 0, nil, ipcSet(t, id, auth.UID(s.SemPerm.UID), auth.GID(s.SemPerm.GID), perms)
+ return 0, nil, ipcSet(t, id, &s)
case linux.GETPID:
v, err := getPID(t, id, num)
@@ -243,24 +241,13 @@ func remove(t *kernel.Task, id ipc.ID) error {
return r.Remove(id, creds)
}
-func ipcSet(t *kernel.Task, id ipc.ID, uid auth.UID, gid auth.GID, perms fs.FilePermissions) error {
+func ipcSet(t *kernel.Task, id ipc.ID, ds *linux.SemidDS) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
return linuxerr.EINVAL
}
-
- creds := auth.CredentialsFromContext(t)
- kuid := creds.UserNamespace.MapToKUID(uid)
- if !kuid.Ok() {
- return linuxerr.EINVAL
- }
- kgid := creds.UserNamespace.MapToKGID(gid)
- if !kgid.Ok() {
- return linuxerr.EINVAL
- }
- owner := fs.FileOwner{UID: kuid, GID: kgid}
- return set.Change(t, creds, owner, perms)
+ return set.Set(t, ds)
}
func ipcStat(t *kernel.Task, id ipc.ID) (*linux.SemidDS, error) {
diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md
index 5aad31b78..82ee2c521 100644
--- a/pkg/sentry/vfs/README.md
+++ b/pkg/sentry/vfs/README.md
@@ -1,9 +1,5 @@
# The gVisor Virtual Filesystem
-THIS PACKAGE IS CURRENTLY EXPERIMENTAL AND NOT READY OR ENABLED FOR PRODUCTION
-USE. For the filesystem implementation currently used by gVisor, see the `fs`
-package.
-
## Implementation Notes
### Reference Counting
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index e0dfe5813..2f34bf8dd 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -729,7 +729,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
return
}
l := int(opts[i+1])
- if i < 2 || i+l > limit {
+ if l < 2 || i+l > limit {
return
}
i += l
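Editor's note: the old guard tested the loop index rather than the option length, so a malformed length byte of 0 or 1 could make i += l stall the scan or misparse the stream. A hedged sketch of the corrected loop invariant, assuming opts holds raw TCP option bytes:

func wellFormedOptions(opts []byte) bool {
	limit := len(opts)
	for i := 0; i < limit; {
		switch opts[i] {
		case 0: // end-of-options
			return true
		case 1: // NOP occupies a single byte
			i++
		default:
			if i+1 >= limit {
				return false // kind byte with no room for a length byte
			}
			l := int(opts[i+1])
			// A length below 2 cannot even cover the kind and length
			// octets, and would stall or misalign the scan.
			if l < 2 || i+l > limit {
				return false
			}
			i += l
		}
	}
	return true
}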
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index 95ade0e5c..1f18213e5 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -49,9 +49,9 @@ const (
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
- // unspecifiedEthernetAddress is the unspecified ethernet address
+ // UnspecifiedEthernetAddress is the unspecified ethernet address
// (all bits set to 0).
- unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
// EthernetBroadcastAddress is an ethernet address that addresses every node
// on a local link.
@@ -134,7 +134,7 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
return false
}
- if addr == unspecifiedEthernetAddress {
+ if addr == UnspecifiedEthernetAddress {
return false
}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index bf9ccbf1a..adc04e855 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -44,7 +44,7 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
@@ -91,7 +91,7 @@ func TestIsMulticastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index f26c857eb..d02eea93c 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -290,3 +290,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.AddHeader.
func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index b427c6170..8211a2031 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -42,6 +42,14 @@ type Endpoint struct {
nested.Endpoint
}
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ if l := e.Endpoint.LinkAddress(); len(l) != 0 {
+ return l
+ }
+ return header.UnspecifiedEthernetAddress
+}
+
// DeliverNetworkPacket implements stack.NetworkDispatcher.
func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
@@ -57,18 +65,22 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP
// Capabilities implements stack.LinkEndpoint.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities()
+ c := e.Endpoint.Capabilities()
+ if c&stack.CapabilityLoopback == 0 {
+ c |= stack.CapabilityResolutionRequired
+ }
+ return c
}
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
return e.Endpoint.WritePacket(r, proto, pkt)
}
// WritePackets implements stack.LinkEndpoint.
func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- linkAddr := e.Endpoint.LinkAddress()
+ linkAddr := e.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
@@ -83,7 +95,10 @@ func (e *Endpoint) MaxHeaderLength() uint16 {
}
// ARPHardwareType implements stack.LinkEndpoint.
-func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ if a := e.Endpoint.ARPHardwareType(); a != header.ARPHardwareNone {
+ return a
+ }
return header.ARPHardwareEther
}
@@ -97,3 +112,8 @@ func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkP
}
eth.Encode(&fields)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.Endpoint.WriteRawPacket(pkt)
+}
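Editor's note: the net effect of these overrides is that an inner endpoint reporting no link address now falls back to the all-zero Ethernet address, loopback links stop advertising CapabilityResolutionRequired, and raw writes delegate to the wrapped endpoint. A minimal check, assuming inner is any stack.LinkEndpoint handed to ethernet.New:

ep := ethernet.New(inner)
if ep.LinkAddress() == header.UnspecifiedEthernetAddress {
	// inner reported an empty address; outgoing frames carry an
	// all-zero source address instead of an empty one.
}
if ep.Capabilities()&stack.CapabilityResolutionRequired == 0 {
	// inner is loopback-like, so no neighbor resolution is required.
}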
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 48356c343..058242f96 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -505,6 +505,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
}
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 7012d8829..d7bbfa639 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -103,3 +103,6 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 3e2a1aa94..844f5959b 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -131,6 +131,11 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*InjectableEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 3e816b0c7..14cb96d63 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -152,3 +152,8 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.child.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.child.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 5030b6ba1..3ed0aa3fe 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -121,3 +121,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.
func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 40bd5560b..dc63e5fb0 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -228,3 +228,8 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.lower.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 30cf659b8..66efe6472 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -202,6 +202,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
eth.Encode(ethHdr)
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a95602aa5..13900205d 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -155,3 +155,6 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index a71400ee9..b0e4237bd 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -80,6 +80,11 @@ func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBuffe
return pkts.Len(), nil
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*countedEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index 605e9ef8d..4d4d98caf 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -101,6 +101,11 @@ func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return heade
func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
// how many random bytes will be copied in the Transport Header.
// extraHeaderReserveLength indicates how much extra space will be reserved for
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 72f66441f..ccb69393b 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -342,6 +342,10 @@ func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, p
return n, nil
}
+func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index dfe2c886f..57b3348b2 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -846,6 +846,14 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link.
+ //
+ // If the link-layer has its own header, the payload must already include the
+ // header.
+ //
+ // WriteRawPacket takes ownership of the packet.
+ WriteRawPacket(*PacketBuffer) tcpip.Error
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
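Editor's note: every LinkEndpoint must now implement WriteRawPacket. Nested and qdisc endpoints delegate to their child, while leaf endpoints that cannot serialize arbitrary frames return *tcpip.ErrNotSupported. A caller-side sketch, assuming ep is a stack.LinkEndpoint and frame already contains any required link-layer header:

pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
	Data: buffer.View(frame).ToVectorisedView(),
})
// WriteRawPacket takes ownership of pkt; it must not be reused afterwards.
if err := ep.WriteRawPacket(pkt); err != nil {
	// Most leaf endpoints in this change report *tcpip.ErrNotSupported.
}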
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index c73890c4c..cfa8a2e8f 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -119,8 +119,7 @@ type Stack struct {
// by the stack.
icmpRateLimiter *ICMPRateLimiter
- // seed is a one-time random value initialized at stack startup
- // and is used to seed the TCP port picking on active connections
+ // seed is a one-time random value initialized at stack startup.
//
// TODO(gvisor.dev/issue/940): S/R this field.
seed uint32
@@ -161,6 +160,10 @@ type Stack struct {
// This is required to prevent potential ACK loops.
// Setting this to 0 will disable all rate limiting.
tcpInvalidRateLimit time.Duration
+
+ // tsOffsetSecret is the secret key for generating timestamp offsets
+ // initialized at stack startup.
+ tsOffsetSecret uint32
}
// UniqueID is an abstract generator of unique identifiers.
@@ -384,6 +387,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
tcpInvalidRateLimit: defaultTCPInvalidRateLimit,
+ tsOffsetSecret: randomGenerator.Uint32(),
}
// Add specified network protocols.
@@ -1819,8 +1823,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol
return nic.setNUDConfigs(proto, c)
}
-// Seed returns a 32 bit value that can be used as a seed value for port
-// picking, ISN generation etc.
+// Seed returns a 32 bit value that can be used as a seed value.
//
// NOTE: The seed is generated once during stack initialization only.
func (s *Stack) Seed() uint32 {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index dda57e225..824cf6526 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -479,7 +479,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: d.stack.Seed(),
+ seed: d.stack.seed,
}
}
if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b3d8951ff..55854ba59 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -321,28 +321,26 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
defer route.Release()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(route.MaxHeaderLength()),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ pkt.Owner = owner
+
if e.ops.GetHeaderIncluded() {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
return 0, err
}
- } else {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(route.MaxHeaderLength()),
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
- pkt.Owner = owner
- if err := route.WritePacket(stack.NetworkHeaderParams{
- Protocol: e.TransProto,
- TTL: route.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- return 0, err
- }
+ return int64(len(payloadBytes)), nil
}
+ if err := route.WritePacket(stack.NetworkHeaderParams{
+ Protocol: e.TransProto,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
+ return 0, err
+ }
return int64(len(payloadBytes)), nil
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 8436d2cf0..c3922bbe5 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -96,6 +96,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index aa413ad05..f8269efa6 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -72,7 +72,8 @@ func encodeMSS(mss uint16) uint32 {
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
+ stack *stack.Stack
+ protocol *protocol
// rcvWnd is the receive window that is sent by this listening context
// in the initial SYN-ACK.
@@ -119,9 +120,10 @@ func timeStamp(clock tcpip.Clock) uint32 {
}
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stk,
+ protocol: protocol,
rcvWnd: rcvWnd,
hasher: sha1.New(),
v6Only: v6Only,
@@ -201,7 +203,7 @@ func (l *listenContext) useSynCookies() bool {
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -213,7 +215,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
return nil, err
}
- n := newEndpoint(l.stack, netProto, queue)
+ n := newEndpoint(l.stack, l.protocol, netProto, queue)
n.ops.SetV6Only(l.v6Only)
n.TransportEndpointInfo.ID = s.id
n.boundNICID = s.nicID
@@ -244,10 +246,10 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
// On success, a handshake h is returned with h.ep.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
+func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed())
+ isn := generateSecureISN(s.id, l.stack.Clock(), l.protocol.seqnumSecret)
ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
@@ -323,7 +325,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// established endpoint is returned with e.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
+func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
h, err := l.startHandshake(s, opts, queue, owner)
if err != nil {
return nil, err
@@ -495,7 +497,7 @@ func (e *endpoint) notifyAborted() {
// cookies to accept connections.
//
// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error {
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error {
defer s.decRef()
h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
@@ -581,7 +583,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !ctx.useSynCookies() {
s.incRef()
atomic.AddInt32(&e.synRcvdCount, 1)
- return e.handleSynSegment(ctx, s, &opts)
+ return e.handleSynSegment(ctx, s, opts)
}
route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
@@ -600,10 +602,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
+ if opts.TS {
+ // Create a barely-sufficient endpoint to calculate the TSVal.
+ pseudoEndpoint := endpoint{
+ TCPEndpointStateInner: stack.TCPEndpointStateInner{
+ TSOffset: e.protocol.tsOffset(s.dstAddr, s.srcAddr),
+ },
+ stack: e.stack,
+ }
+ synOpts.TSVal = pseudoEndpoint.tsValNow()
+ }
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
fields := tcpFields{
id: s.id,
@@ -670,7 +681,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
- rcvdSynOptions := &header.TCPSynOptions{
+ rcvdSynOptions := header.TCPSynOptions{
MSS: mssTable[data],
// Disable Window scaling as original SYN is
// lost.
@@ -725,25 +736,22 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
n.isRegistered = true
-
- // clear the tsOffset for the newly created
- // endpoint as the Timestamp was already
- // randomly offset when the original SYN-ACK was
- // sent above.
- n.TSOffset = 0
+ n.TSOffset = n.protocol.tsOffset(s.dstAddr, s.srcAddr)
// Switch state to connected.
n.isConnectNotified = true
- n.transitionToStateEstablishedLocked(&handshake{
- ep: n,
- iss: iss,
- ackNum: irs + 1,
- rcvWnd: seqnum.Size(n.initialReceiveWindow()),
- sndWnd: s.window,
- rcvWndScale: e.rcvWndScaleForHandshake(),
- sndWndScale: rcvdSynOptions.WS,
- mss: rcvdSynOptions.MSS,
- })
+ h := handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ sampleRTTWithTSOnly: true,
+ }
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be processed by the newly established endpoint.
@@ -779,7 +787,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.ops.GetV6Only()
- ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
+ ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
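Editor's note: because tsOffset (defined in protocol.go below) is a pure function of a per-stack secret and the address pair, the SYN-cookie path can compute its TSVal from a throwaway pseudo-endpoint, and the endpoint created once the ACK arrives recomputes the identical offset rather than zeroing it. A sketch of the property that makes this safe, with hypothetical variable names:

// p is the tcp protocol instance; dstAddr/srcAddr are the connection's pair.
off1 := p.tsOffset(dstAddr, srcAddr) // used for the cookie SYN-ACK's TSVal
off2 := p.tsOffset(dstAddr, srcAddr) // adopted as n.TSOffset after the ACK
// off1 == off2, so the peer observes one monotonic timestamp clock
// across the cookie exchange and the established connection.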
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 93ed161f9..f331655fc 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -105,6 +105,11 @@ type handshake struct {
// sendSYNOpts is the cached values for the SYN options to be sent.
sendSYNOpts header.TCPSynOptions
+
+ // sampleRTTWithTSOnly is true when the segment was retransmitted, or when we
+ // can't tell; in that case RTT can only be sampled when the incoming segment
+ // carries the timestamp option.
+ sampleRTTWithTSOnly bool
}
func (e *endpoint) newHandshake() *handshake {
@@ -117,10 +122,12 @@ func (e *endpoint) newHandshake() *handshake {
h.resetState()
// Store reference to handshake state in endpoint.
e.h = h
+ // By the time the handshake is created, e.ID is already initialized.
+ e.TSOffset = e.protocol.tsOffset(e.ID.LocalAddress, e.ID.RemoteAddress)
return h
}
-func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) *handshake {
h := e.newHandshake()
h.resetToSynRcvd(isn, irs, opts, deferAccept)
return h
@@ -150,20 +157,23 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.protocol.seqnumSecret)
}
// generateSecureISN generates a secure Initial Sequence number based on the
// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value {
isnHasher := jenkins.Sum32(seed)
- isnHasher.Write([]byte(id.LocalAddress))
- isnHasher.Write([]byte(id.RemoteAddress))
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
+ _, _ = isnHasher.Write([]byte(id.LocalAddress))
+ _, _ = isnHasher.Write([]byte(id.RemoteAddress))
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
- isnHasher.Write(portBuf)
+ _, _ = isnHasher.Write(portBuf)
binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
- isnHasher.Write(portBuf)
+ _, _ = isnHasher.Write(portBuf)
// The time period here is 64ns. This is similar to what Linux uses to
// generate a sequence number that overlaps less than once per MSL (2
// minutes).
@@ -190,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -251,10 +261,10 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
rcvSynOpts := parseSynSegmentOptions(s)
// Remember if the Timestamp option was negotiated.
- h.ep.maybeEnableTimestamp(&rcvSynOpts)
+ h.ep.maybeEnableTimestamp(rcvSynOpts)
// Remember if the SACKPermitted option was negotiated.
- h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+ h.ep.maybeEnableSACKPermitted(rcvSynOpts)
// Remember the sequence we'll ack from now on.
h.ackNum = s.sequenceNumber + 1
@@ -266,8 +276,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// and the handshake is completed.
if s.flags.Contains(header.TCPFlagAck) {
h.state = handshakeCompleted
-
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
@@ -283,7 +292,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
@@ -353,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: h.ep.SendTSOk,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.SACKPermitted,
MSS: h.ep.amss,
@@ -402,9 +411,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
if h.ep.SendTSOk && s.parsedOptions.TS {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
+
h.state = handshakeCompleted
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be processed by the newly established endpoint.
@@ -480,7 +490,7 @@ func (h *handshake) start() {
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
@@ -557,6 +567,10 @@ func (h *handshake) complete() tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, h.sendSYNOpts)
+ // If we have ever retransmitted the SYN-ACK or
+ // SYN segment, we should only measure RTT if the
+ // TS option is present.
+ h.sampleRTTWithTSOnly = true
}
case wakerForNotification:
@@ -600,6 +614,40 @@ func (h *handshake) complete() tcpip.Error {
return nil
}
+// transitionToStateEstablishedLocked transitions the endpoint of the handshake
+// to an established state given the last segment received from peer. It also
+// initializes sender/receiver.
+func (h *handshake) transitionToStateEstablishedLocked(s *segment) {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ h.ep.snd = newSender(h.ep, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ now := h.ep.stack.Clock().NowMonotonic()
+
+ var rtt time.Duration
+ if h.ep.SendTSOk && s.parsedOptions.TSEcr != 0 {
+ rtt = h.ep.elapsed(now, s.parsedOptions.TSEcr)
+ }
+ if !h.sampleRTTWithTSOnly && rtt == 0 {
+ rtt = now.Sub(h.startTime)
+ }
+
+ if rtt > 0 {
+ h.ep.snd.updateRTO(rtt)
+ }
+
+ h.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ h.ep.rcv = newReceiver(h.ep, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ h.ep.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
+ h.ep.rcvQueueInfo.rcvQueueMu.Unlock()
+
+ h.ep.setEndpointState(StateEstablished)
+}
+
type backoffTimer struct {
timeout time.Duration
maxTimeout time.Duration
@@ -873,7 +921,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
+ offset += header.EncodeTSOption(e.tsValNow(), e.recentTimestamp(), options[offset:])
}
if e.SACKPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -965,26 +1013,6 @@ func (e *endpoint) completeWorkerLocked() {
}
}
-// transitionToStateEstablisedLocked transitions a given endpoint
-// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver.
-func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
- e.rcvQueueInfo.rcvQueueMu.Unlock()
-
- e.setEndpointState(StateEstablished)
-}
-
// transitionToStateCloseLocked ensures that the endpoint is
// cleaned up from the transport demuxer, "before" moving to
// StateClose. This will ensure that no packet will be
@@ -1286,7 +1314,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoopDone is called at the end of protocolMainLoop.
// +checklocksrelease:e.mu
-func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) {
+func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer) {
if e.snd != nil {
e.snd.resendTimer.cleanup()
e.snd.probeTimer.cleanup()
@@ -1331,7 +1359,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.hardError = err
e.workerCleanup = true
- e.protocolMainLoopDone(closeTimer, &closeWaker)
+ e.protocolMainLoopDone(closeTimer)
return err
}
}
@@ -1559,7 +1587,7 @@ loop:
// just want to terminate the loop and cleanup the
// endpoint.
cleanupOnError(nil)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
+ e.protocolMainLoopDone(closeTimer)
return nil
case StateTimeWait:
fallthrough
@@ -1568,7 +1596,7 @@ loop:
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
+ e.protocolMainLoopDone(closeTimer)
return nil
}
}
@@ -1592,13 +1620,13 @@ loop:
// Handle any StateError transition from StateTimeWait.
if e.EndpointState() == StateError {
cleanupOnError(nil)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
+ e.protocolMainLoopDone(closeTimer)
return nil
}
e.transitionToStateCloseLocked()
- e.protocolMainLoopDone(closeTimer, &closeWaker)
+ e.protocolMainLoopDone(closeTimer)
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
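Editor's note: transitionToStateEstablishedLocked now seeds the RTO from whichever RTT sample is trustworthy, the timestamp echo when negotiated, or the whole handshake duration only if nothing was ever retransmitted. A worked sketch with hypothetical values:

// Suppose tsValNow() is 1000 when the final handshake segment arrives
// and the peer echoed TSEcr=800: the TS-based sample is 200ms.
rtt := time.Duration(1000-800) * time.Millisecond
// Without a TS sample, fall back to handshake wall time, but only if
// no SYN/SYN-ACK was retransmitted (sampleRTTWithTSOnly is false).
if rtt == 0 && !sampleRTTWithTSOnly {
	rtt = now.Sub(startTime)
}
if rtt > 0 {
	snd.updateRTO(rtt) // the first RTT sample seeds the RTO
}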
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 355719beb..0623ee8ed 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -20,7 +20,6 @@ import (
"fmt"
"io"
"math"
- "math/rand"
"runtime"
"strings"
"sync/atomic"
@@ -378,6 +377,7 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
+ protocol *protocol `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
@@ -803,9 +803,10 @@ type keepalive struct {
waker sleep.Waker `state:"nosave"`
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
+ stack: s,
+ protocol: protocol,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: header.TCPProtocolNumber,
@@ -874,7 +875,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset(e.stack.Rand())
+
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
@@ -2198,7 +2199,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
- h := jenkins.Sum32(e.stack.Seed())
+ h := jenkins.Sum32(e.protocol.portOffsetSecret)
for _, s := range [][]byte{
[]byte(e.ID.LocalAddress),
[]byte(e.ID.RemoteAddress),
@@ -2904,46 +2905,29 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
// the SYN options indicate that timestamp option was negotiated. It also
// initializes the recentTS with the value provided in synOpts.TSval.
-func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+func (e *endpoint) maybeEnableTimestamp(synOpts header.TCPSynOptions) {
if synOpts.TS {
e.SendTSOk = true
e.setRecentTimestamp(synOpts.TSVal)
}
}
-// timestamp returns the timestamp value to be used in the TSVal field of the
-// timestamp option for outgoing TCP segments for a given endpoint.
-func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset)
+func (e *endpoint) tsVal(now tcpip.MonotonicTime) uint32 {
+ return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + e.TSOffset
}
-// tcpTimeStamp returns a timestamp offset by the provided offset. This is
-// not inlined above as it's used when SYN cookies are in use and endpoint
-// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
- d := curTime.Sub(tcpip.MonotonicTime{})
- return uint32(d.Milliseconds()) + offset
+func (e *endpoint) tsValNow() uint32 {
+ return e.tsVal(e.stack.Clock().NowMonotonic())
}
-// timeStampOffset returns a randomized timestamp offset to be used when sending
-// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset(rng *rand.Rand) uint32 {
- // Initialize a random tsOffset that will be added to the recentTS
- // everytime the timestamp is sent when the Timestamp option is enabled.
- //
- // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
- // why this is required.
- //
- // NOTE: This is not completely to spec as normally this should be
- // initialized in a manner analogous to how sequence numbers are
- // randomized per connection basis. But for now this is sufficient.
- return rng.Uint32()
+func (e *endpoint) elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
+ return time.Duration(e.tsVal(now)-tsEcr) * time.Millisecond
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
-func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+func (e *endpoint) maybeEnableSACKPermitted(synOpts header.TCPSynOptions) {
var v tcpip.TCPSACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
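Editor's note: tsVal is a per-connection millisecond clock (monotonic time plus TSOffset), so elapsed recovers an RTT from an echoed TSEcr with plain uint32 subtraction, which stays correct across wrap-around. A standalone check of the arithmetic, with an offset chosen near wrap for illustration:

package main

import (
	"fmt"
	"time"
)

func main() {
	offset := uint32(0xfffffff0) // a TSOffset near wrap, for illustration
	sent := uint32(100) + offset // TSVal transmitted at t=100ms; wraps
	nowVal := uint32(350) + offset
	rtt := time.Duration(nowVal-sent) * time.Millisecond
	fmt.Println(rtt) // 250ms, despite both values wrapping past zero
}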
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 952ccacdd..f2e8b3840 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -170,6 +170,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
snd.probeTimer.init(s.Clock(), &snd.probeWaker)
}
e.stack = s
+ e.protocol = protocolFromStack(s)
e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
epState := EndpointState(e.origEndpointState)
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2e709ed78..128ef09e3 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -54,7 +54,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
- listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
}
}
@@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
f := r.forwarder
- ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
+ ep, err := f.listen.performHandshake(r.segment, header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 18b834243..b0ffd2429 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -49,10 +50,6 @@ const (
// MaxBufferSize is the largest size a receive/send buffer can grow to.
MaxBufferSize = 4 << 20 // 4MB
- // MaxUnprocessedSegments is the maximum number of unprocessed segments
- // that can be queued for a given endpoint.
- MaxUnprocessedSegments = 300
-
// DefaultTCPLingerTimeout is the amount of time that sockets linger in
// FIN_WAIT_2 state before being marked closed.
DefaultTCPLingerTimeout = 60 * time.Second
@@ -96,6 +93,11 @@ type protocol struct {
maxRetries uint32
synRetries uint8
dispatcher dispatcher
+
+ // The following secrets are initialized once and remain unchanged thereafter.
+ seqnumSecret uint32
+ portOffsetSecret uint32
+ tsOffsetSecret uint32
}
// Number returns the tcp protocol number.
@@ -105,7 +107,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
// NewEndpoint creates a new tcp endpoint.
func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
- return newEndpoint(p.stack, netProto, waiterQueue), nil
+ return newEndpoint(p.stack, p, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
@@ -156,6 +158,24 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
return stack.UnknownDestinationPacketHandled
}
+func (p *protocol) tsOffset(src, dst tcpip.Address) uint32 {
+ // Initialize a random tsOffset that will be added to the recentTS
+ // every time the timestamp is sent when the Timestamp option is enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // TODO(https://gvisor.dev/issues/6473): This is not really secure as
+ // it does not use the recommended algorithm linked above.
+ h := jenkins.Sum32(p.tsOffsetSecret)
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
+ _, _ = h.Write([]byte(src))
+ _, _ = h.Write([]byte(dst))
+ return h.Sum32()
+}
+
// replyWithReset replies to the given segment with a reset segment.
//
// If the passed TTL is 0, then the route's default TTL will be used.
@@ -292,22 +312,26 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
case *tcpip.TCPMinRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.minRTO = MinRTO
+ } else if minRTO := time.Duration(*v); minRTO <= p.maxRTO {
+ p.minRTO = minRTO
} else {
- p.minRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.maxRTO = MaxRTO
+ } else if maxRTO := time.Duration(*v); maxRTO >= p.minRTO {
+ p.maxRTO = maxRTO
} else {
- p.maxRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRetriesOption:
@@ -479,7 +503,15 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
maxRTO: MaxRTO,
maxRetries: MaxRetries,
recovery: tcpip.TCPRACKLossDetection,
+ seqnumSecret: s.Rand().Uint32(),
+ portOffsetSecret: s.Rand().Uint32(),
+ tsOffsetSecret: s.Rand().Uint32(),
}
p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
+
+// protocolFromStack retrieves the tcp.protocol instance from stack s.
+func protocolFromStack(s *stack.Stack) *protocol {
+ return s.TransportProtocolInstance(ProtocolNumber).(*protocol)
+}
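Editor's note: the three secrets replace the single stack seed for TCP, so ISN generation, ephemeral-port offsets, and timestamp offsets are keyed independently; each is fixed for the stack's lifetime but unpredictable across restarts. A standalone sketch of the tsOffset shape using the same jenkins package (the secret value and addresses are illustrative):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
)

func tsOffset(secret uint32, src, dst tcpip.Address) uint32 {
	h := jenkins.Sum32(secret)
	_, _ = h.Write([]byte(src)) // per hash.Hash, Write never returns an error
	_, _ = h.Write([]byte(dst))
	return h.Sum32()
}

func main() {
	src := tcpip.Address("\x0a\x00\x00\x01")
	dst := tcpip.Address("\x0a\x00\x00\x02")
	// Stable for a given (secret, src, dst) triple, different per pair.
	fmt.Println(tsOffset(42, src, dst) == tsOffset(42, src, dst)) // true
}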
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 0da4eafaa..3b055c294 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -80,7 +80,6 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
func (rc *rackControl) update(seg *segment, ackSeg *segment) {
rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime)
- tsOffset := rc.snd.ep.TSOffset
// If the ACK is for a retransmitted packet, do not update if it is a
// spurious inference which is determined by below checks:
@@ -92,7 +91,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// step 2
if seg.xmitCount > 1 {
if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
- if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) {
+ if ackSeg.parsedOptions.TSEcr < rc.snd.ep.tsVal(seg.xmitTime) {
return
}
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 64302f576..2fabf1594 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -382,6 +382,9 @@ func (s *sender) updateRTO(rtt time.Duration) {
if s.RTO < s.minRTO {
s.RTO = s.minRTO
}
+ if s.RTO > s.maxRTO {
+ s.RTO = s.maxRTO
+ }
}
// resendSegment resends the first unacknowledged segment.
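Editor's note: with the clamp above and the SetOption validation in protocol.go, minRTO <= RTO <= maxRTO now holds as a stack-wide invariant, and configuring an inconsistent pair fails rather than silently inverting the bounds. A usage sketch, assuming s is a *stack.Stack with the TCP protocol registered:

maxRTO := tcpip.TCPMaxRTOOption(2 * time.Second)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTO); err != nil {
	// 2s is above the default 200ms minimum, so this succeeds.
}
minRTO := tcpip.TCPMinRTOOption(5 * time.Second)
err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTO)
// err is a *tcpip.ErrInvalidOptionValue: 5s would exceed the 2s
// maximum configured just above.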
@@ -1342,10 +1345,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// some new data, i.e., only if it advances the left edge of
// the send window.
if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
- // TSVal/Ecr values sent by Netstack are at a millisecond
- // granularity.
- elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
- s.updateRTO(elapsed)
+ s.updateRTO(s.ep.elapsed(s.ep.stack.Clock().NowMonotonic(), rcvdSeg.parsedOptions.TSEcr))
}
if s.shouldSchedulePTO() {
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 89e9fb886..c35db7c95 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -33,7 +33,6 @@ const (
tsOptionSize = 12
maxTCPOptionSize = 40
mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
- latency = 5 * time.Millisecond
)
func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) {
@@ -163,7 +162,10 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
if !enableRACK {
setStackTCPRecovery(t, c, 0)
}
- createConnectedWithSACKAndTS(c)
+ // The delay should be below the initial RTO (1s), otherwise retransmission
+ // will start. Choose a relatively large value so that the estimated RTT
+ // stays high even after a few rounds of undelayed RTT samples.
+ c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */)
data := make([]byte, numPackets*maxPayload)
for i := range data {
@@ -181,9 +183,6 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
for i := 0; i < numPackets; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- // This delay is added to increase RTT as low RTT can cause TLP
- // before sending ACK.
- time.Sleep(latency)
}
return data
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 83e0653b9..6255355bb 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -35,13 +35,13 @@ import (
// SACKPermitted option enabled if the stack in the context has the SACK support
// enabled.
func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
}
// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
// option enabled if the stack in the context has SACK and TS enabled.
func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
}
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
@@ -108,7 +108,7 @@ func TestSackDisabledConnect(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
data := []byte{1, 2, 3}
@@ -170,7 +170,7 @@ func TestSackPermittedAccept(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
// Now verify no SACK blocks are
// received when sack is disabled.
data := []byte{1, 2, 3}
@@ -244,7 +244,7 @@ func TestSackDisabledAccept(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now verify no SACK blocks are
// received when sack is disabled.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index db6b0955a..90b74a2a7 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -2143,7 +2144,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
}
- c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
// Bump up the receive buffer size such that, when the receive window grows,
// the scaled window exceeds maxUint16.
@@ -2535,7 +2536,7 @@ func TestScaledWindowAccept(t *testing.T) {
// Do 3-way handshake.
// wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
- c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -3532,6 +3533,12 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
+ // Wait for the connection to timeout after MaxRetries retransmits.
+ initRTO := time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -3554,8 +3561,6 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
),
)
}
- // Wait for the connection to timeout after MaxRetries retransmits.
- initRTO := 1 * time.Second
select {
case <-notifyCh:
case <-time.After((2 << numRetries) * initRTO):
@@ -3590,9 +3595,13 @@ func TestMaxRTO(t *testing.T) {
defer c.Cleanup()
rto := 1 * time.Second
- opt := tcpip.TCPMaxRTOOption(rto)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ minRTOOpt := tcpip.TCPMinRTOOption(rto / 2)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
+ maxRTOOpt := tcpip.TCPMaxRTOOption(rto)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err)
}
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3618,8 +3627,8 @@ func TestMaxRTO(t *testing.T) {
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
- if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
- t.Errorf("Retransmit interval not capped to MaxRTO.\n")
+ if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed)
}
}
}
@@ -3670,6 +3679,10 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ minRTOOpt := tcpip.TCPMinRTOOption(time.Second)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
// Disabling PMTU discovery causes all packets sent from this socket to
@@ -4946,7 +4959,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
}
for i := start; i <= end; i++ {
- if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
+ if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
@@ -6304,7 +6317,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -6385,7 +6398,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// maximum buffer size defined above.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
// NOTE: The timestamp values in the sent packets are meaningless to the
// peer so we just increment the timestamp value by 1 every batch as we
@@ -6515,7 +6528,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// maximum buffer size used by stack.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
tsVal := rawEP.TSVal
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
@@ -7430,6 +7443,11 @@ func TestTCPUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ initRTO := 1 * time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -7440,7 +7458,6 @@ func TestTCPUserTimeout(t *testing.T) {
// Ensure that on the next retransmit timer fire, the user timeout has
// expired.
- initRTO := 1 * time.Second
userTimeout := initRTO / 2
v := tcpip.TCPUserTimeoutOption(userTimeout)
if err := c.EP.SetSockOpt(&v); err != nil {
@@ -7954,6 +7971,151 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
}
}
+func TestHandshakeRTT(t *testing.T) {
+ type testCase struct {
+ connect bool
+ tsEnabled bool
+ useCookie bool
+ retrans bool
+ delay time.Duration
+ wantRTT time.Duration
+ }
+ var testCases []testCase
+ for _, connect := range []bool{false, true} {
+ for _, tsEnabled := range []bool{false, true} {
+ for _, useCookie := range []bool{false, true} {
+ for _, retrans := range []bool{false, true} {
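+ // SYN cookies are only relevant to passively opened connections,
+ // so the connect+cookie combination is skipped.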
+ if connect && useCookie {
+ continue
+ }
+ delay := 800 * time.Millisecond
+ if retrans {
+ delay = 1200 * time.Millisecond
+ }
+ wantRTT := delay
+ // If syncookie is enabled, sample RTT only when TS option is enabled.
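+ // A cookie-validated listener retains no SYN-ACK send time, so the
+ // timestamp echo is its only way to measure RTT.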
+ if !retrans && useCookie && !tsEnabled {
+ wantRTT = 0
+ }
+ // If retransmitted, sample RTT only when TS option is enabled.
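+ // Per Karn's algorithm, RTT samples taken across retransmitted
+ // segments are ambiguous unless timestamps disambiguate them.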
+ if retrans && !tsEnabled {
+ wantRTT = 0
+ }
+ testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT})
+ }
+ }
+ }
+ }
+ for _, tt := range testCases {
+ tt := tt
+ t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) {
+ t.Parallel()
+ c := context.New(t, defaultMTU)
+ if tt.useCookie {
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
+ }
+ synOpts := header.TCPSynOptions{}
+ if tt.tsEnabled {
+ synOpts.TS = true
+ synOpts.TSVal = 42
+ }
+ if tt.connect {
+ c.CreateConnectedWithOptions(synOpts, tt.delay)
+ } else {
+ synOpts.MSS = defaultIPv4MSS
+ synOpts.WS = -1
+ c.AcceptWithOptions(-1, synOpts, tt.delay)
+ }
+ var info tcpip.TCPInfoOption
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+ if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT {
+ t.Fatalf("got info.RTT=%s, want %s", got, tt.wantRTT)
+ }
+ if info.RTTVar != 0 && tt.wantRTT == 0 {
+ t.Fatalf("got info.RTTVar=%s, want 0", info.RTTVar)
+ }
+ if info.RTTVar == 0 && tt.wantRTT != 0 {
+ t.Fatalf("got info.RTTVar=0, want non-zero")
+ }
+ })
+ }
+}
+
+func TestSetRTO(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ minRTO, maxRTO := tcpRTOMinMax(t, c)
+ for _, tt := range []struct {
+ name string
+ RTO time.Duration
+ minRTO time.Duration
+ maxRTO time.Duration
+ err tcpip.Error
+ }{
+ {
+ name: "invalid minRTO",
+ minRTO: maxRTO + time.Second,
+ err: &tcpip.ErrInvalidOptionValue{},
+ },
+ {
+ name: "invalid maxRTO",
+ maxRTO: minRTO - time.Millisecond,
+ err: &tcpip.ErrInvalidOptionValue{},
+ },
+ {
+ name: "valid minRTO",
+ minRTO: maxRTO - time.Second,
+ },
+ {
+ name: "valid maxRTO",
+ maxRTO: minRTO + time.Millisecond,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ var opt tcpip.SettableTransportProtocolOption
+ if tt.minRTO > 0 {
+ min := tcpip.TCPMinRTOOption(tt.minRTO)
+ opt = &min
+ }
+ if tt.maxRTO > 0 {
+ max := tcpip.TCPMaxRTOOption(tt.maxRTO)
+ opt = &max
+ }
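+ // Each case sets exactly one of the two options; the stack is expected
+ // to validate it against the current value of the other bound.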
+ err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt)
+ if got, want := err, tt.err; got != want {
+ t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want)
+ }
+ if tt.err == nil {
+ minRTO, maxRTO := tcpRTOMinMax(t, c)
+ if tt.minRTO > 0 && tt.minRTO != minRTO {
+ t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO)
+ }
+ if tt.maxRTO > 0 && tt.maxRTO != maxRTO {
+ t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO)
+ }
+ }
+ })
+ }
+}
+
+func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) {
+ t.Helper()
+ var minOpt tcpip.TCPMinRTOOption
+ var maxOpt tcpip.TCPMaxRTOOption
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &minOpt); err != nil {
+ t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", minOpt, err)
+ }
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &maxOpt); err != nil {
+ t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", maxOpt, err)
+ }
+ return time.Duration(minOpt), time.Duration(maxOpt)
+}
+
// generateRandomPayload generates a random byte slice of the specified length
// causing a fatal test failure if it is unable to do so.
func generateRandomPayload(t *testing.T, n int) []byte {
@@ -8047,7 +8209,7 @@ func TestSendBufferTuning(t *testing.T) {
if err := c.EP.GetSockOpt(&info); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
- outSz = (int64(info.SndCwnd) * packetOverheadFactor * (maxPayload))
+ outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload
}
if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz {
@@ -8056,3 +8218,100 @@ func TestSendBufferTuning(t *testing.T) {
})
}
}
+
+func TestTimestampSynCookies(t *testing.T) {
+ clock := faketime.NewManualClock()
+ tsNow := func() uint32 {
+ return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds())
+ }
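+ // tsNow mirrors the clock the stack is assumed to use when generating
+ // TSVal: milliseconds of monotonic time.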
+ // Advance the clock so that NowMonotonic is non-zero.
+ clock.Advance(time.Second)
+ c := context.NewWithOpts(t, context.Options{
+ EnableV4: true,
+ EnableV6: true,
+ MTU: defaultMTU,
+ Clock: clock,
+ })
+ defer c.Cleanup()
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(42, 0, tcpOpts[2:])
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ RcvWnd: seqnum.Size(512),
+ SeqNum: iss,
+ TCPOpts: tcpOpts[:],
+ })
+ // Get the TSVal of the SYN-ACK.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+ initialTSVal := tcpHdr.ParsedOptions().TSVal
+ // Derive the tsOffset.
+ tsOffset := initialTSVal - tsNow()
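+ // The stack generates TSVal as now+tsOffset, so recovering the offset
+ // lets the test predict the TSVal of later segments.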
+
+ header.EncodeTSOption(420, initialTSVal, tcpOpts[2:])
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ RcvWnd: seqnum.Size(512),
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ TCPOpts: tcpOpts[:],
+ })
+ // Try to accept the connection.
+ c.EP, _, err = ep.Accept(nil)
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ defer wq.EventUnregister(&we)
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ } else if err != nil {
+ t.Fatalf("failed to accept: %s", err)
+ }
+
+ // Advance the clock again so that we expect the next TSVal to change.
+ clock.Advance(time.Second)
+ data := []byte{1, 2, 3}
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // The endpoint should have derived the correct TSOffset, so the TSVal
+ // of the segment it sends should match our expectation.
+ if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want {
+ t.Fatalf("got TSVal = %d, want %d", got, want)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 1deb1fe4d..65925daa5 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -32,7 +32,7 @@ import (
// createConnectedWithTimestampOption creates and connects c.ep with the
// timestamp option enabled.
func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1})
}
// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
@@ -131,7 +131,7 @@ func TestTimeStampDisabledConnect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
@@ -147,7 +147,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
@@ -209,7 +209,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
}
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 96e4849d2..6e55a7a32 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -122,6 +122,9 @@ type Options struct {
// MTU indicates the maximum transmission unit on the link layer.
MTU uint32
+
+ // Clock is the optional clock used by the Stack; if nil, the stack's
+ // default clock is used.
+ Clock tcpip.Clock
}
// Context provides an initialized Network stack and a link layer endpoint
@@ -182,6 +185,7 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
stackOpts := stack.Options{
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: opts.Clock,
}
if opts.EnableV4 {
stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
@@ -879,13 +883,21 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
)
}
+// CreateConnectedWithOptionsNoDelay calls CreateConnectedWithOptions with no
+// delay.
+func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint {
+ return c.CreateConnectedWithOptions(wantOptions, 0 /* delay */)
+}
+
// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
// options enabled and returns a RawEndpoint which represents the other end of
-// the connection.
+// the connection. It sleeps for the given delay before sending the SYN-ACK,
+// which gives c.EP a higher RTT estimate so that spurious TLPs aren't sent in
+// tests, reducing flakiness.
//
// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
// does not carry an option that was not requested.
-func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
@@ -911,18 +923,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// TS value.
mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{
- MSS: mss,
- TS: true,
- WS: int(c.WindowScale),
- SACKPermitted: c.SACKEnabled(),
- }),
- ),
+ synChecker := checker.TCP(
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{
+ MSS: mss,
+ TS: true,
+ WS: int(c.WindowScale),
+ SACKPermitted: c.SACKEnabled(),
+ }),
)
+ checker.IPv4(c.t, b, synChecker)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
@@ -948,6 +959,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
iss := seqnum.Value(TestInitialSequenceNumber)
+ if delay > 0 {
+ // Sleep so that RTT is increased.
+ time.Sleep(delay)
+ }
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -959,7 +974,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
})
// Read ACK.
- ackPacket := c.GetPacket()
+ var ackPacket []byte
+ // Ignore retransmitted SYN packets.
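+ // The delay before the SYN-ACK was sent may have exceeded the initial
+ // RTO, in which case the endpoint retransmits its SYN; verify and skip
+ // any such packets.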
+ for {
+ packet := c.GetPacket()
+ if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 {
+ checker.IPv4(c.t, packet, synChecker)
+ } else {
+ ackPacket = packet
+ break
+ }
+ }
// Verify TCP header fields.
tcpCheckers := []checker.TransportChecker{
@@ -1016,13 +1041,19 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
}
}
-// AcceptWithOptions initializes a listening endpoint and connects to it with the
-// provided options enabled. It also verifies that the SYN-ACK has the expected
-// values for the provided options.
+// AcceptWithOptionsNoDelay calls AcceptWithOptions with no delay.
+func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */)
+}
+
+// AcceptWithOptions initializes a listening endpoint and connects to it with
+// the provided options enabled. It delays before the final ACK of the 3WHS is
+// sent. It also verifies that the SYN-ACK has the expected values for the
+// provided options.
//
// The function returns a RawEndpoint representing the other end of the accepted
// endpoint.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
@@ -1045,7 +1076,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
+ rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1077,13 +1108,14 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
// PassiveConnectWithOptions.
func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
synOptions.WS = -1
- c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
+ c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */)
}
// PassiveConnectWithOptions initiates a new connection (with the specified TCP
// options enabled) to the port on which the Context.ep is listening for new
// connections. It also validates that the SYN-ACK has the expected values for
-// the enabled options.
+// the enabled options. The final ACK of the handshake is delayed by the
+// specified duration.
//
// NOTE: MSS is not a negotiated option and it can be asymmetric
// in each direction. This function uses the maxPayload to set the MSS to be
@@ -1093,7 +1125,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
@@ -1180,7 +1212,10 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
ackHeaders.TCPOpts = opts[:]
}
- // Send ACK.
+ // Delay if needed, then send the ACK.
+ if delay > 0 {
+ time.Sleep(delay)
+ }
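+ // The listening endpoint samples handshake RTT between sending the
+ // SYN-ACK and receiving this ACK, so the sleep above inflates that
+ // sample.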
c.SendPacket(nil, ackHeaders)
c.RcvdWindowScale = uint8(rcvdSynOptions.WS)