summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/ioctl.go15
-rw-r--r--pkg/cpuid/cpuid.go8
-rw-r--r--pkg/cpuid/cpuid_arm64.go14
-rw-r--r--pkg/cpuid/cpuid_x86.go15
-rw-r--r--pkg/sentry/arch/fpu/BUILD1
-rw-r--r--pkg/sentry/arch/fpu/fpu.go13
-rw-r--r--pkg/sentry/arch/fpu/fpu_unsafe.go31
-rw-r--r--pkg/sentry/fs/proc/exec_args.go2
-rw-r--r--pkg/sentry/fs/proc/task.go116
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go60
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go9
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go28
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go61
-rw-r--r--pkg/sentry/kernel/task_exec.go8
-rw-r--r--pkg/sentry/kernel/task_exit.go9
-rw-r--r--pkg/sentry/kernel/task_image.go2
-rw-r--r--pkg/sentry/loader/loader.go10
-rw-r--r--pkg/sentry/loader/vdso.go40
-rw-r--r--pkg/sentry/mm/mm.go1
-rw-r--r--pkg/sentry/mm/pma.go41
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go10
-rw-r--r--pkg/sentry/socket/socket.go1
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go2
-rw-r--r--pkg/shim/runtimeoptions/runtimeoptions.go7
-rw-r--r--pkg/syserr/BUILD1
-rw-r--r--pkg/syserr/syserr.go16
-rw-r--r--pkg/tcpip/internal/tcp/BUILD12
-rw-r--r--pkg/tcpip/internal/tcp/tcp.go48
-rw-r--r--pkg/tcpip/link/channel/channel.go3
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go5
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go3
-rw-r--r--pkg/tcpip/link/loopback/loopback.go31
-rw-r--r--pkg/tcpip/link/muxed/injectable.go5
-rw-r--r--pkg/tcpip/link/nested/nested.go5
-rw-r--r--pkg/tcpip/link/pipe/pipe.go3
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go5
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go3
-rw-r--r--pkg/tcpip/link/waitable/waitable.go3
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go5
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go5
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/forwarding_test.go4
-rw-r--r--pkg/tcpip/stack/nic.go54
-rw-r--r--pkg/tcpip/stack/registration.go8
-rw-r--r--pkg/tcpip/stack/stack.go73
-rw-r--r--pkg/tcpip/stack/tcp.go3
-rw-r--r--pkg/tcpip/transport/BUILD13
-rw-r--r--pkg/tcpip/transport/datagram.go49
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go1
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD43
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go722
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go56
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go209
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go48
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go6
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go11
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go5
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go205
-rw-r--r--pkg/tcpip/transport/transport.go16
-rw-r--r--pkg/tcpip/transport/udp/BUILD2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go859
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go65
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go21
-rw-r--r--runsc/boot/loader.go23
-rw-r--r--runsc/config/config.go3
-rw-r--r--runsc/config/flags.go1
-rw-r--r--test/packetimpact/dut/posix_server.cc2
-rw-r--r--test/runner/main.go1
-rw-r--r--test/syscalls/linux/BUILD3
-rw-r--r--test/syscalls/linux/eventfd.cc25
-rw-r--r--test/syscalls/linux/inotify.cc28
-rw-r--r--test/syscalls/linux/packet_socket.cc8
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc8
-rw-r--r--test/syscalls/linux/sendfile.cc100
-rw-r--r--test/syscalls/linux/socket_unix.cc15
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc15
-rw-r--r--test/syscalls/linux/splice.cc103
-rw-r--r--test/util/BUILD4
-rw-r--r--tools/bazeldefs/cc.bzl1
-rw-r--r--tools/defs.bzl3
-rw-r--r--website/_config.yml3
-rw-r--r--website/assets/images/2021-08-31-rack-figure1.pngbin0 -> 111367 bytes
-rw-r--r--website/assets/images/2021-08-31-rack-figure2.pngbin0 -> 71529 bytes
-rw-r--r--website/assets/images/2021-08-31-rack-figure3.pngbin0 -> 64347 bytes
-rw-r--r--website/blog/2021-08-31-gvisor-rack.md120
-rw-r--r--website/blog/BUILD10
91 files changed, 2353 insertions, 1266 deletions
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 006b5a525..29062c97a 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -170,3 +170,18 @@ const (
KCOV_MODE_TRACE_PC = 2
KCOV_MODE_TRACE_CMP = 3
)
+
+// Attestation ioctls.
+var (
+ SIGN_ATTESTATION_REPORT = IOC(_IOC_READ, 's', 1, 65)
+)
+
+// SizeOfQuoteInputData is the number of bytes in the input data of ioctl call
+// to get quote.
+const SizeOfQuoteInputData = 64
+
+// SignReport is a struct that gets signed quote from input data.
+type SignReport struct {
+ data [64]byte
+ quote []byte
+}
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 69eeb7528..4d5e062a8 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -37,6 +37,14 @@ package cpuid
// arch/arm64/include/uapi/asm/hwcap.h
type Feature int
+// HostFeatureSet returns a FeatureSet that matches that of the host machine.
+// Callers must not mutate the returned FeatureSet.
+func HostFeatureSet() *FeatureSet {
+ return hostFeatureSet
+}
+
+var hostFeatureSet = getHostFeatureSet()
+
// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
// subset of the host feature set.
type ErrIncompatible struct {
diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go
index 6e61d562f..04645aed5 100644
--- a/pkg/cpuid/cpuid_arm64.go
+++ b/pkg/cpuid/cpuid_arm64.go
@@ -230,6 +230,16 @@ type FeatureSet struct {
CPURevision uint8
}
+// Clone returns a copy of fs.
+func (fs *FeatureSet) Clone() *FeatureSet {
+ fs2 := *fs
+ fs2.Set = make(map[Feature]bool)
+ for f, b := range fs.Set {
+ fs2.Set[f] = b
+ }
+ return &fs2
+}
+
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
// Noop on arm64.
func (fs *FeatureSet) CheckHostCompatible() error {
@@ -292,9 +302,9 @@ func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
}
-// HostFeatureSet uses hwCap to get host values and construct a feature set
+// getHostFeatureSet uses hwCap to get host values and construct a feature set
// that matches that of the host machine.
-func HostFeatureSet() *FeatureSet {
+func getHostFeatureSet() *FeatureSet {
s := make(map[Feature]bool)
for f := range arm64FeatureStrings {
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index dc17cade8..a92d32d74 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -627,6 +627,17 @@ type FeatureSet struct {
CacheLine uint32
}
+// Clone returns a copy of fs.
+func (fs *FeatureSet) Clone() *FeatureSet {
+ fs2 := *fs
+ fs2.Set = make(map[Feature]bool)
+ for f, b := range fs.Set {
+ fs2.Set[f] = b
+ }
+ fs2.Caches = append([]Cache(nil), fs.Caches...)
+ return &fs2
+}
+
// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
// equivalent to the "flags" field in /proc/cpuinfo.
func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
@@ -961,13 +972,13 @@ func (fs *FeatureSet) UseXsaveopt() bool {
// HostID executes a native CPUID instruction.
func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32)
-// HostFeatureSet uses cpuid to get host values and construct a feature set
+// getHostFeatureSet uses cpuid to get host values and construct a feature set
// that matches that of the host machine. Note that there are several places
// where there appear to be some unnecessary assignments between register names
// (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show
// where the different feature blocks come from, to make the code easier to
// inspect and read.
-func HostFeatureSet() *FeatureSet {
+func getHostFeatureSet() *FeatureSet {
// eax=0 gets max supported feature and vendor ID.
_, bx, cx, dx := HostID(0, 0)
vendorID := vendorIDFromRegs(bx, cx, dx)
diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD
index 6cdd21b1b..1f371e513 100644
--- a/pkg/sentry/arch/fpu/BUILD
+++ b/pkg/sentry/arch/fpu/BUILD
@@ -9,6 +9,7 @@ go_library(
"fpu_amd64.go",
"fpu_amd64.s",
"fpu_arm64.go",
+ "fpu_unsafe.go",
],
visibility = ["//:sandbox"],
deps = [
diff --git a/pkg/sentry/arch/fpu/fpu.go b/pkg/sentry/arch/fpu/fpu.go
index 867d309a3..62bde19d3 100644
--- a/pkg/sentry/arch/fpu/fpu.go
+++ b/pkg/sentry/arch/fpu/fpu.go
@@ -17,7 +17,6 @@ package fpu
import (
"fmt"
- "reflect"
)
// State represents floating point state.
@@ -40,15 +39,3 @@ type ErrLoadingState struct {
func (e ErrLoadingState) Error() string {
return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supportedFeatures, e.savedFeatures)
}
-
-// alignedBytes returns a slice of size bytes, aligned in memory to the given
-// alignment. This is used because we require certain structures to be aligned
-// in a specific way (for example, the X86 floating point data).
-func alignedBytes(size, alignment uint) []byte {
- data := make([]byte, size+alignment-1)
- offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment))
- if offset == 0 {
- return data[:size:size]
- }
- return data[alignment-offset:][:size:size]
-}
diff --git a/pkg/sentry/arch/fpu/fpu_unsafe.go b/pkg/sentry/arch/fpu/fpu_unsafe.go
new file mode 100644
index 000000000..c91dc99be
--- /dev/null
+++ b/pkg/sentry/arch/fpu/fpu_unsafe.go
@@ -0,0 +1,31 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fpu
+
+import (
+ "unsafe"
+)
+
+// alignedBytes returns a slice of size bytes, aligned in memory to the given
+// alignment. This is used because we require certain structures to be aligned
+// in a specific way (for example, the X86 floating point data).
+func alignedBytes(size, alignment uint) []byte {
+ data := make([]byte, size+alignment-1)
+ offset := uint(uintptr(unsafe.Pointer(&data[0])) % uintptr(alignment))
+ if offset == 0 {
+ return data[:size:size]
+ }
+ return data[alignment-offset:][:size:size]
+}
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
index 379429ab2..75dc5d204 100644
--- a/pkg/sentry/fs/proc/exec_args.go
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -107,7 +107,7 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
return 0, linuxerr.EINVAL
}
- m, err := getTaskMM(f.t)
+ m, err := getTaskMMIncRef(f.t)
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 89a799b21..03f2a882d 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -41,10 +41,24 @@ import (
// LINT.IfChange
-// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's
-// users count is incremented, and must be decremented by the caller when it is
-// no longer in use.
-func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) {
+// getTaskMM gets the kernel task's MemoryManager. No additional reference is
+// taken on mm here. This is safe because MemoryManager.destroy is required to
+// leave the MemoryManager in a state where it's still usable as a
+// DynamicBytesSource.
+func getTaskMM(t *kernel.Task) *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// getTaskMMIncRef returns t's MemoryManager. If getTaskMMIncRef succeeds, the
+// MemoryManager's users count is incremented, and must be decremented by the
+// caller when it is no longer in use.
+func getTaskMMIncRef(t *kernel.Task) (*mm.MemoryManager, error) {
if t.ExitState() == kernel.TaskExitDead {
return nil, linuxerr.ESRCH
}
@@ -269,21 +283,18 @@ func (e *exe) executable() (file fsbridge.File, err error) {
if err := checkTaskState(e.t); err != nil {
return nil, err
}
- e.t.WithMuLocked(func(t *kernel.Task) {
- mm := t.MemoryManager()
- if mm == nil {
- err = linuxerr.EACCES
- return
- }
+ mm := getTaskMM(e.t)
+ if mm == nil {
+ return nil, linuxerr.EACCES
+ }
- // The MemoryManager may be destroyed, in which case
- // MemoryManager.destroy will simply set the executable to nil
- // (with locks held).
- file = mm.Executable()
- if file == nil {
- err = linuxerr.ESRCH
- }
- })
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ file = mm.Executable()
+ if file == nil {
+ err = linuxerr.ESRCH
+ }
return
}
@@ -463,7 +474,7 @@ func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
if dst.NumBytes() == 0 {
return 0, nil
}
- mm, err := getTaskMM(m.t)
+ mm, err := getTaskMMIncRef(m.t)
if err != nil {
return 0, nil
}
@@ -494,22 +505,9 @@ func newMaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inod
return newProcInode(ctx, seqfile.NewSeqFile(ctx, &mapsData{t}), msrc, fs.SpecialFile, t)
}
-func (md *mapsData) mm() *mm.MemoryManager {
- var tmm *mm.MemoryManager
- md.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- // No additional reference is taken on mm here. This is safe
- // because MemoryManager.destroy is required to leave the
- // MemoryManager in a state where it's still usable as a SeqSource.
- tmm = mm
- }
- })
- return tmm
-}
-
// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
func (md *mapsData) NeedsUpdate(generation int64) bool {
- if mm := md.mm(); mm != nil {
+ if mm := getTaskMM(md.t); mm != nil {
return mm.NeedsUpdate(generation)
}
return true
@@ -517,7 +515,7 @@ func (md *mapsData) NeedsUpdate(generation int64) bool {
// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (md *mapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
- if mm := md.mm(); mm != nil {
+ if mm := getTaskMM(md.t); mm != nil {
return mm.ReadMapsSeqFileData(ctx, h)
}
return []seqfile.SeqData{}, 0
@@ -534,22 +532,9 @@ func newSmaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Ino
return newProcInode(ctx, seqfile.NewSeqFile(ctx, &smapsData{t}), msrc, fs.SpecialFile, t)
}
-func (sd *smapsData) mm() *mm.MemoryManager {
- var tmm *mm.MemoryManager
- sd.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- // No additional reference is taken on mm here. This is safe
- // because MemoryManager.destroy is required to leave the
- // MemoryManager in a state where it's still usable as a SeqSource.
- tmm = mm
- }
- })
- return tmm
-}
-
// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
func (sd *smapsData) NeedsUpdate(generation int64) bool {
- if mm := sd.mm(); mm != nil {
+ if mm := getTaskMM(sd.t); mm != nil {
return mm.NeedsUpdate(generation)
}
return true
@@ -557,7 +542,7 @@ func (sd *smapsData) NeedsUpdate(generation int64) bool {
// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (sd *smapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
- if mm := sd.mm(); mm != nil {
+ if mm := getTaskMM(sd.t); mm != nil {
return mm.ReadSmapsSeqFileData(ctx, h)
}
return []seqfile.SeqData{}, 0
@@ -627,12 +612,10 @@ func (s *taskStatData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
fmt.Fprintf(&buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
+ if mm := getTaskMM(s.t); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
fmt.Fprintf(&buf, "%d %d ", vss, rss/hostarch.PageSize)
// rsslim.
@@ -677,12 +660,10 @@ func (s *statmData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([
}
var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
+ if mm := getTaskMM(s.t); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
var buf bytes.Buffer
fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize)
@@ -734,12 +715,13 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) (
if fdTable := t.FDTable(); fdTable != nil {
fds = fdTable.CurrentMaxFDs()
}
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- data = mm.VirtualDataSize()
- }
})
+
+ if mm := getTaskMM(s.t); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
fmt.Fprintf(&buf, "FDSize:\t%d\n", fds)
fmt.Fprintf(&buf, "VmSize:\t%d kB\n", vss>>10)
fmt.Fprintf(&buf, "VmRSS:\t%d kB\n", rss>>10)
@@ -925,7 +907,7 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
return 0, linuxerr.EINVAL
}
- m, err := getTaskMM(f.t)
+ m, err := getTaskMMIncRef(f.t)
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index bd6b30397..43440ec19 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -995,7 +995,7 @@ func (d *dentry) refreshSizeLocked(ctx context.Context) error {
if d.writeFD < 0 {
d.handleMu.RUnlock()
// Ask the gofer if we don't have a host FD.
- return d.updateFromGetattrLocked(ctx)
+ return d.updateFromGetattrLocked(ctx, p9file{})
}
var stat unix.Statx_t
@@ -1014,33 +1014,35 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
// updating stale attributes in d.updateFromP9AttrsLocked().
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
- return d.updateFromGetattrLocked(ctx)
+ return d.updateFromGetattrLocked(ctx, p9file{})
}
// Preconditions:
// * !d.isSynthetic().
// * d.metadataMu is locked.
// +checklocks:d.metadataMu
-func (d *dentry) updateFromGetattrLocked(ctx context.Context) error {
- // Use d.readFile or d.writeFile, which represent 9P FIDs that have been
- // opened, in preference to d.file, which represents a 9P fid that has not.
- // This may be significantly more efficient in some implementations. Prefer
- // d.writeFile over d.readFile since some filesystem implementations may
- // update a writable handle's metadata after writes to that handle, without
- // making metadata updates immediately visible to read-only handles
- // representing the same file.
- d.handleMu.RLock()
- handleMuRLocked := true
- var file p9file
- switch {
- case !d.writeFile.isNil():
- file = d.writeFile
- case !d.readFile.isNil():
- file = d.readFile
- default:
- file = d.file
- d.handleMu.RUnlock()
- handleMuRLocked = false
+func (d *dentry) updateFromGetattrLocked(ctx context.Context, file p9file) error {
+ handleMuRLocked := false
+ if file.isNil() {
+ // Use d.readFile or d.writeFile, which represent 9P FIDs that have
+ // been opened, in preference to d.file, which represents a 9P fid that
+ // has not. This may be significantly more efficient in some
+ // implementations. Prefer d.writeFile over d.readFile since some
+ // filesystem implementations may update a writable handle's metadata
+ // after writes to that handle, without making metadata updates
+ // immediately visible to read-only handles representing the same file.
+ d.handleMu.RLock()
+ switch {
+ case !d.writeFile.isNil():
+ file = d.writeFile
+ handleMuRLocked = true
+ case !d.readFile.isNil():
+ file = d.readFile
+ handleMuRLocked = true
+ default:
+ file = d.file
+ d.handleMu.RUnlock()
+ }
}
_, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask())
@@ -2044,9 +2046,17 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
d := fd.dentry()
const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
- // TODO(jamieliu): Use specialFileFD.handle.file for the getattr if
- // available?
- if err := d.updateFromGetattr(ctx); err != nil {
+ // Use specialFileFD.handle.file for the getattr if available, for the
+ // same reason that we try to use open file handles in
+ // dentry.updateFromGetattrLocked().
+ var file p9file
+ if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
+ file = sffd.handle.file
+ }
+ d.metadataMu.Lock()
+ err := d.updateFromGetattrLocked(ctx, file)
+ d.metadataMu.Unlock()
+ if err != nil {
return linux.Statx{}, err
}
}
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 618092ef1..520487066 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -36,6 +36,10 @@ func (d *dentry) isCopiedUp() bool {
//
// Preconditions: filesystem.renameMu must be locked.
func (d *dentry) copyUpLocked(ctx context.Context) error {
+ return d.copyUpMaybeSyntheticMountpointLocked(ctx, false /* forSyntheticMountpoint */)
+}
+
+func (d *dentry) copyUpMaybeSyntheticMountpointLocked(ctx context.Context, forSyntheticMountpoint bool) error {
// Fast path.
if d.isCopiedUp() {
return nil
@@ -59,7 +63,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
// d is a filesystem root with no upper layer.
return linuxerr.EROFS
}
- if err := d.parent.copyUpLocked(ctx); err != nil {
+ if err := d.parent.copyUpMaybeSyntheticMountpointLocked(ctx, forSyntheticMountpoint); err != nil {
return err
}
@@ -168,7 +172,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
case linux.S_IFDIR:
if err := vfsObj.MkdirAt(ctx, d.fs.creds, &newpop, &vfs.MkdirOptions{
- Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ ForSyntheticMountpoint: forSyntheticMountpoint,
}); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index c04c80590..3b3dcf836 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -462,13 +462,21 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
return d, nil
}
+type createType int
+
+const (
+ createNonDirectory createType = iota
+ createDirectory
+ createSyntheticMountpoint
+)
+
// doCreateAt checks that creating a file at rp is permitted, then invokes
// create to do so.
//
// Preconditions:
// * !rp.Done().
// * For the final path component in rp, !rp.ShouldFollowSymlink().
-func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error {
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, ct createType, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
@@ -504,7 +512,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return linuxerr.EEXIST
}
- if !dir && rp.MustBeDir() {
+ if ct == createNonDirectory && rp.MustBeDir() {
return linuxerr.ENOENT
}
@@ -518,7 +526,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
}
// Ensure that the parent directory is copied-up so that we can create the
// new file in the upper layer.
- if err := parent.copyUpLocked(ctx); err != nil {
+ if err := parent.copyUpMaybeSyntheticMountpointLocked(ctx, ct == createSyntheticMountpoint); err != nil {
return err
}
@@ -529,7 +537,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
parent.dirents = nil
ev := linux.IN_CREATE
- if dir {
+ if ct != createNonDirectory {
ev |= linux.IN_ISDIR
}
parent.watches.Notify(ctx, name, uint32(ev), 0 /* cookie */, vfs.InodeEvent, false /* unlinked */)
@@ -618,7 +626,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
// LinkAt implements vfs.FilesystemImpl.LinkAt.
func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
if rp.Mount() != vd.Mount() {
return linuxerr.EXDEV
}
@@ -671,7 +679,11 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
- return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ ct := createDirectory
+ if opts.ForSyntheticMountpoint {
+ ct = createSyntheticMountpoint
+ }
+ return fs.doCreateAt(ctx, rp, ct, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
vfsObj := fs.vfsfs.VirtualFilesystem()
pop := vfs.PathOperation{
Root: parent.upperVD,
@@ -722,7 +734,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
// Disallow attempts to create whiteouts.
if opts.Mode&linux.S_IFMT == linux.S_IFCHR && opts.DevMajor == 0 && opts.DevMinor == 0 {
return linuxerr.EPERM
@@ -1476,7 +1488,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
vfsObj := fs.vfsfs.VirtualFilesystem()
pop := vfs.PathOperation{
Root: parent.upperVD,
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 34b0c4f63..d3f9cf489 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -40,7 +40,7 @@ import (
// Linux 3.18, the limit is five lines." - user_namespaces(7)
const maxIDMapLines = 5
-// mm gets the kernel task's MemoryManager. No additional reference is taken on
+// getMM gets the kernel task's MemoryManager. No additional reference is taken on
// mm here. This is safe because MemoryManager.destroy is required to leave the
// MemoryManager in a state where it's still usable as a DynamicBytesSource.
func getMM(task *kernel.Task) *mm.MemoryManager {
@@ -608,12 +608,10 @@ func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime())))
var vss, rss uint64
- s.task.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
+ if mm := getMM(s.task); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
fmt.Fprintf(buf, "%d %d ", vss, rss/hostarch.PageSize)
// rsslim.
@@ -649,13 +647,10 @@ var _ dynamicInode = (*statmData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
var vss, rss uint64
- s.task.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
-
+ if mm := getMM(s.task); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize)
return nil
}
@@ -779,12 +774,12 @@ func (s *statusFD) Generate(ctx context.Context, buf *bytes.Buffer) error {
if fdTable := t.FDTable(); fdTable != nil {
fds = fdTable.CurrentMaxFDs()
}
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- data = mm.VirtualDataSize()
- }
})
+ if mm := getMM(s.task); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
// Filesystem user/group IDs aren't implemented; effective UID/GID are used
// instead.
fmt.Fprintf(buf, "Uid:\t%d\t%d\t%d\t%d\n", ruid, euid, suid, euid)
@@ -945,25 +940,17 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent
return vfs.VirtualDentry{}, "", err
}
- var err error
- var exec fsbridge.File
- s.task.WithMuLocked(func(t *kernel.Task) {
- mm := t.MemoryManager()
- if mm == nil {
- err = linuxerr.EACCES
- return
- }
+ mm := getMM(s.task)
+ if mm == nil {
+ return vfs.VirtualDentry{}, "", linuxerr.EACCES
+ }
- // The MemoryManager may be destroyed, in which case
- // MemoryManager.destroy will simply set the executable to nil
- // (with locks held).
- exec = mm.Executable()
- if exec == nil {
- err = linuxerr.ESRCH
- }
- })
- if err != nil {
- return vfs.VirtualDentry{}, "", err
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ exec := mm.Executable()
+ if exec == nil {
+ return vfs.VirtualDentry{}, "", linuxerr.ESRCH
}
defer exec.DecRef(ctx)
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 9175b911c..db91fc4d8 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -222,9 +222,15 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
// Update credentials to reflect the execve. This should precede switching
// MMs to ensure that dumpability has been reset first, if needed.
t.updateCredsForExecLocked()
- t.image.release()
+ oldImage := t.image
t.image = *r.image
t.mu.Unlock()
+
+ // Don't hold t.mu while calling t.image.release(), that may
+ // attempt to acquire TaskImage.MemoryManager.mappingMu, a lock order
+ // violation.
+ oldImage.release()
+
t.unstopVforkParent()
t.p.FullStateChanged()
// NOTE(b/30316266): All locks must be dropped prior to calling Activate.
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 342e5debe..b3931445b 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -230,9 +230,16 @@ func (*runExitMain) execute(t *Task) taskRunState {
t.tg.pidns.owner.mu.Lock()
t.updateRSSLocked()
t.tg.pidns.owner.mu.Unlock()
+
+ // Release the task image resources. Accessing these fields must be
+ // done with t.mu held, but the mm.DecUsers() call must be done outside
+ // of that lock.
t.mu.Lock()
- t.image.release()
+ mm := t.image.MemoryManager
+ t.image.MemoryManager = nil
+ t.image.fu = nil
t.mu.Unlock()
+ mm.DecUsers(t)
// Releasing the MM unblocks a blocked CLONE_VFORK parent.
t.unstopVforkParent()
diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go
index c132c27ef..6002ffb42 100644
--- a/pkg/sentry/kernel/task_image.go
+++ b/pkg/sentry/kernel/task_image.go
@@ -53,7 +53,7 @@ type TaskImage struct {
}
// release releases all resources held by the TaskImage. release is called by
-// the task when it execs into a new TaskImage or exits.
+// the task when it execs into a new TaskImage.
func (image *TaskImage) release() {
// Nil out pointers so that if the task is saved after release, it doesn't
// follow the pointers to possibly now-invalid objects.
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 6a356779c..2759ef71e 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -295,15 +295,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
m.SetEnvvEnd(sl.EnvvEnd)
m.SetAuxv(auxv)
m.SetExecutable(ctx, file)
-
- symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn")
- if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to find rt_sigreturn in vdso: %v", err), syserr.FromError(err).ToLinux())
- }
-
- // Found rt_sigretrun.
- addr := uint64(vdsoAddr) + symbolValue - vdsoPrelink
- m.SetVDSOSigReturn(addr)
+ m.SetVDSOSigReturn(uint64(vdsoAddr) + vdsoSigreturnOffset - vdsoPrelink)
ac.SetIP(uintptr(loaded.entry))
ac.SetStack(uintptr(stack.Bottom))
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 3abd2ee7d..bcee6aef6 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -19,7 +19,6 @@ import (
"debug/elf"
"fmt"
"io"
- "strings"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/context"
@@ -177,27 +176,6 @@ type VDSO struct {
phdrs []elf.ProgHeader `state:".([]elfProgHeader)"`
}
-// getSymbolValueFromVDSO returns the specific symbol value in vdso.so.
-func getSymbolValueFromVDSO(symbol string) (uint64, error) {
- f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary))
- if err != nil {
- return 0, err
- }
- syms, err := f.Symbols()
- if err != nil {
- return 0, err
- }
-
- for _, sym := range syms {
- if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF {
- if strings.Contains(sym.Name, symbol) {
- return sym.Value, nil
- }
- }
- }
- return 0, fmt.Errorf("no %v in vdso.so", symbol)
-}
-
// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
// param page for updating by the kernel.
func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
@@ -388,3 +366,21 @@ func (v *VDSO) Release(ctx context.Context) {
v.ParamPage.DecRef(ctx)
v.vdso.DecRef(ctx)
}
+
+var vdsoSigreturnOffset = func() uint64 {
+ f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary))
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse vdso.so as ELF file: %v", err))
+ }
+ syms, err := f.Symbols()
+ if err != nil {
+ panic(fmt.Sprintf("failed to read symbols from vdso.so: %v", err))
+ }
+ const sigreturnSymbol = "__kernel_rt_sigreturn"
+ for _, sym := range syms {
+ if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF && sym.Name == sigreturnSymbol {
+ return sym.Value
+ }
+ }
+ panic(fmt.Sprintf("no symbol %q in vdso.so", sigreturnSymbol))
+}()
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 57969b26c..0fca59b64 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -28,6 +28,7 @@
// memmap.File locks
// mm.aioManager.mu
// mm.AIOContext.mu
+// kernel.TaskSet.mu
//
// Only mm.MemoryManager.Fork is permitted to lock mm.MemoryManager.activeMu in
// multiple mm.MemoryManagers, as it does so in a well-defined order (forked
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 9f4cc238f..05cdcd8ae 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -324,20 +324,37 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
panic(fmt.Sprintf("pma %v needs to be copied for writing, but is not readable: %v", pseg.Range(), oldpma))
}
}
- // The majority of copy-on-write breaks on executable pages
- // come from:
- //
- // - The ELF loader, which must zero out bytes on the last
- // page of each segment after the end of the segment.
- //
- // - gdb's use of ptrace to insert breakpoints.
- //
- // Neither of these cases has enough spatial locality to
- // benefit from copying nearby pages, so if the vma is
- // executable, only copy the pages required.
var copyAR hostarch.AddrRange
- if vseg.ValuePtr().effectivePerms.Execute {
+ if vma := vseg.ValuePtr(); vma.effectivePerms.Execute {
+ // The majority of copy-on-write breaks on executable
+ // pages come from:
+ //
+ // - The ELF loader, which must zero out bytes on the
+ // last page of each segment after the end of the
+ // segment.
+ //
+ // - gdb's use of ptrace to insert breakpoints.
+ //
+ // Neither of these cases has enough spatial locality
+ // to benefit from copying nearby pages, so if the vma
+ // is executable, only copy the pages required.
copyAR = pseg.Range().Intersect(ar)
+ } else if vma.growsDown {
+ // In most cases, the new process will not use most of
+ // its stack before exiting or invoking execve(); it is
+ // especially unlikely to return very far down its call
+ // stack, since async-signal-safety concerns in
+ // multithreaded programs prevent the new process from
+ // being able to do much. So only copy up to one page
+ // before and after the pages required.
+ stackMaskAR := ar
+ if newStart := stackMaskAR.Start - hostarch.PageSize; newStart < stackMaskAR.Start {
+ stackMaskAR.Start = newStart
+ }
+ if newEnd := stackMaskAR.End + hostarch.PageSize; newEnd > stackMaskAR.End {
+ stackMaskAR.End = newEnd
+ }
+ copyAR = pseg.Range().Intersect(stackMaskAR)
} else {
copyAR = pseg.Range().Intersect(maskAR)
}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index e347442e7..bf5ec4558 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -48,6 +48,7 @@ go_library(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 2f9462cee..8cf2f29e4 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -59,8 +59,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -2045,7 +2045,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
return syserr.ErrInvalidEndpointState
- } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ } else if isUDPSocket(skType, skProto) && transport.DatagramEndpointState(ep.State()) != transport.DatagramEndpointStateInitial {
return syserr.ErrInvalidEndpointState
}
@@ -3331,10 +3331,10 @@ func (s *socketOpsCommon) State() uint32 {
}
case isUDPSocket(s.skType, s.protocol):
// UDP socket.
- switch udp.EndpointState(s.Endpoint.State()) {
- case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ switch transport.DatagramEndpointState(s.Endpoint.State()) {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateBound, transport.DatagramEndpointStateClosed:
return linux.TCP_CLOSE
- case udp.StateConnected:
+ case transport.DatagramEndpointStateConnected:
return linux.TCP_ESTABLISHED
default:
return 0
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 658e90bb9..83b9d9389 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -750,6 +750,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
return tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ Port: Ntohs(a.Protocol),
}, family, nil
case linux.AF_UNSPEC:
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index e4de44498..a9cedcf5f 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -133,7 +133,7 @@ func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, f
free := q.limit - q.used
if l > free && truncate {
- if free == 0 {
+ if free <= 0 {
// Message can't fit right now.
q.mu.Unlock()
return 0, false, syserr.ErrWouldBlock
diff --git a/pkg/shim/runtimeoptions/runtimeoptions.go b/pkg/shim/runtimeoptions/runtimeoptions.go
index 072dd87f0..e76d73ea7 100644
--- a/pkg/shim/runtimeoptions/runtimeoptions.go
+++ b/pkg/shim/runtimeoptions/runtimeoptions.go
@@ -15,3 +15,10 @@
// Package runtimeoptions contains the runtimeoptions proto.
package runtimeoptions
+
+import proto "github.com/gogo/protobuf/proto"
+
+func init() {
+ // TODO(gvisor.dev/issue/6449): Upgrade runtimeoptions.proto after upgrading to containerd 1.5
+ proto.RegisterType((*Options)(nil), "runtimeoptions.v1.Options")
+}
diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD
index 1cd5d641d..d8c4c9613 100644
--- a/pkg/syserr/BUILD
+++ b/pkg/syserr/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/abi/linux/errno",
"//pkg/errors",
"//pkg/errors/linuxerr",
+ "//pkg/safecopy",
"//pkg/tcpip",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go
index a5e386e38..b679f3046 100644
--- a/pkg/syserr/syserr.go
+++ b/pkg/syserr/syserr.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux/errno"
"gvisor.dev/gvisor/pkg/errors"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/safecopy"
)
// Error represents an internal error.
@@ -278,15 +279,18 @@ func FromError(err error) *Error {
if err == nil {
return nil
}
- if errno, ok := err.(unix.Errno); ok {
- return FromHost(errno)
- }
- if linuxErr, ok := err.(*errors.Error); ok {
- return FromHost(unix.Errno(linuxErr.Errno()))
+ switch e := err.(type) {
+ case unix.Errno:
+ return FromHost(e)
+ case *errors.Error:
+ return FromHost(unix.Errno(e.Errno()))
+ case safecopy.SegvError, safecopy.BusError, safecopy.AlignmentError:
+ return FromHost(unix.EFAULT)
}
- panic("unknown error: " + err.Error())
+ msg := fmt.Sprintf("err: %s type: %T", err.Error(), err)
+ panic(msg)
}
// ConvertIntr converts the provided error code (err) to another one (intr) if
diff --git a/pkg/tcpip/internal/tcp/BUILD b/pkg/tcpip/internal/tcp/BUILD
new file mode 100644
index 000000000..9ae258a0b
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tcp",
+ srcs = ["tcp.go"],
+ visibility = ["//pkg/tcpip:__subpackages__"],
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/internal/tcp/tcp.go b/pkg/tcpip/internal/tcp/tcp.go
new file mode 100644
index 000000000..0616d368c
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/tcp.go
@@ -0,0 +1,48 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains internal type definitions that are not expected to be
+// used by anyone else outside pkg/tcpip.
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// TSOffset is an offset applied to the value of the TSVal field in the TCP
+// Timestamp option.
+//
+// +stateify savable
+type TSOffset struct {
+ milliseconds uint32
+}
+
+// NewTSOffset creates a new TSOffset from milliseconds.
+func NewTSOffset(milliseconds uint32) TSOffset {
+ return TSOffset{
+ milliseconds: milliseconds,
+ }
+}
+
+// TSVal applies the offset to now and returns the timestamp in milliseconds.
+func (offset TSOffset) TSVal(now tcpip.MonotonicTime) uint32 {
+ return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + offset.milliseconds
+}
+
+// Elapsed calculates the elapsed time given now and the echoed back timestamp.
+func (offset TSOffset) Elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
+ return time.Duration(offset.TSVal(now)-tsEcr) * time.Millisecond
+}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index f26c857eb..d02eea93c 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -290,3 +290,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.AddHeader.
func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index b9db273d0..8211a2031 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -112,3 +112,8 @@ func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkP
}
eth.Encode(&fields)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.Endpoint.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 48356c343..058242f96 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -505,6 +505,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
}
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 7012d8829..ca1f9c08d 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,19 +76,8 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- // Construct data as the unparsed portion for the loopback packet.
- data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
-
- // Because we're immediately turning around and writing the packet back
- // to the rx path, we intentionally don't preserve the remote and local
- // link addresses from the stack.Route we're passed.
- newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: data,
- })
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt)
-
- return nil
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ return e.WriteRawPacket(pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -103,3 +92,19 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ // Construct data as the unparsed portion for the loopback packet.
+ data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+
+ // Because we're immediately turning around and writing the packet back
+ // to the rx path, we intentionally don't preserve the remote and local
+ // link addresses from the stack.Route we're passed.
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: data,
+ })
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, pkt.NetworkProtocolNumber, newPkt)
+
+ return nil
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 3e2a1aa94..844f5959b 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -131,6 +131,11 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*InjectableEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 3e816b0c7..14cb96d63 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -152,3 +152,8 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.child.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.child.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 5030b6ba1..3ed0aa3fe 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -121,3 +121,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.
func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 40bd5560b..dc63e5fb0 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -228,3 +228,8 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.lower.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 30cf659b8..66efe6472 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -202,6 +202,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
eth.Encode(ethHdr)
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a95602aa5..13900205d 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -155,3 +155,6 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index a71400ee9..b0e4237bd 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -80,6 +80,11 @@ func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBuffe
return pkts.Len(), nil
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*countedEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index 605e9ef8d..4d4d98caf 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -101,6 +101,11 @@ func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return heade
func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
// how many random bytes will be copied in the Transport Header.
// extraHeaderReserveLength indicates how much extra space will be reserved for
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index e0847e58a..6c42ab29b 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -85,6 +85,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/internal/tcp",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/transport/tcpconntrack",
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 72f66441f..ccb69393b 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -342,6 +342,10 @@ func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, p
return n, nil
}
+func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index b854d868c..ddc1ddab6 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -734,10 +734,29 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.mu.RUnlock()
// Deliver to interested packet endpoints without holding NIC lock.
+ var packetEPPkt *PacketBuffer
deliverPacketEPs := func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketHost
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(),
+ })
+ // If a link header was populated in the original packet buffer, then
+ // populate it in the packet buffer we provide to packet endpoints as
+ // packet endpoints inspect link headers.
+ packetEPPkt.LinkHeader().Consume(pkt.LinkHeader().View().Size())
+ packetEPPkt.PktType = tcpip.PacketHost
+ }
+
+ ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone())
}
if protoEPs != nil {
protoEPs.forEach(deliverPacketEPs)
@@ -758,13 +777,30 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
eps := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
+ var packetEPPkt *PacketBuffer
eps.forEach(func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketOutgoing
- // Add the link layer header as outgoing packets are intercepted
- // before the link layer header is created.
- n.LinkEndpoint.AddHeader(local, remote, protocol, p)
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: pkt.AvailableHeaderBytes(),
+ Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ })
+ // Add the link layer header as outgoing packets are intercepted before
+ // the link layer header is created and packet endpoints are interested
+ // in the link header.
+ n.LinkEndpoint.AddHeader(local, remote, protocol, packetEPPkt)
+ packetEPPkt.PktType = tcpip.PacketOutgoing
+ }
+
+ ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone())
})
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index dfe2c886f..57b3348b2 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -846,6 +846,14 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link.
+ //
+ // If the link-layer has its own header, the payload must already include the
+ // header.
+ //
+ // WriteRawPacket takes ownership of the packet.
+ WriteRawPacket(*PacketBuffer) tcpip.Error
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 8e5c6edbf..cb741e540 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -72,7 +72,8 @@ type Stack struct {
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
- rawFactory RawFactory
+ rawFactory RawFactory
+ packetEndpointWriteSupported bool
demux *transportDemuxer
@@ -218,6 +219,10 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
+ // AllowPacketEndpointWrite determines if packet endpoints support write
+ // operations.
+ AllowPacketEndpointWrite bool
+
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
// returned by the stack secure RNG.
@@ -359,23 +364,24 @@ func New(opts Options) *Stack {
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: seed,
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: randomGenerator,
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ packetEndpointWriteSupported: opts.AllowPacketEndpointWrite,
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: seed,
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: randomGenerator,
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -1653,9 +1659,27 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress,
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
Data: payload,
})
+ pkt.NetworkProtocolNumber = netProto
return nic.WritePacketToRemote(remote, netProto, pkt)
}
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+ if !ok {
+ return &tcpip.ErrUnknownNICID{}
+ }
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ Data: payload,
+ })
+ pkt.NetworkProtocolNumber = proto
+ return nic.WriteRawPacket(pkt)
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1823,6 +1847,13 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol
return nic.setNUDConfigs(proto, c)
}
+// Seed returns a 32 bit value that can be used as a seed value.
+//
+// NOTE: The seed is generated once during stack initialization only.
+func (s *Stack) Seed() uint32 {
+ return s.seed
+}
+
// Rand returns a reference to a pseudo random generator that can be used
// to generate random numbers as required.
func (s *Stack) Rand() *rand.Rand {
@@ -1940,3 +1971,9 @@ func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProto
return false
}
+
+// PacketEndpointWriteSupported returns true iff packet endpoints support write
+// operations.
+func (s *Stack) PacketEndpointWriteSupported() bool {
+ return s.packetEndpointWriteSupported
+}
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index 93ea83cdc..dc7289441 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -402,7 +403,7 @@ type TCPSndBufState struct {
type TCPEndpointStateInner struct {
// TSOffset is a randomized offset added to the value of the TSVal
// field in the timestamp option.
- TSOffset uint32
+ TSOffset tcp.TSOffset
// SACKPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
diff --git a/pkg/tcpip/transport/BUILD b/pkg/tcpip/transport/BUILD
new file mode 100644
index 000000000..af332ed91
--- /dev/null
+++ b/pkg/tcpip/transport/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "transport",
+ srcs = [
+ "datagram.go",
+ "transport.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/tcpip"],
+)
diff --git a/pkg/tcpip/transport/datagram.go b/pkg/tcpip/transport/datagram.go
new file mode 100644
index 000000000..dfce72c69
--- /dev/null
+++ b/pkg/tcpip/transport/datagram.go
@@ -0,0 +1,49 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// DatagramEndpointState is the state of a datagram-based endpoint.
+type DatagramEndpointState tcpip.EndpointState
+
+// The states a datagram-based endpoint may be in.
+const (
+ _ DatagramEndpointState = iota
+ DatagramEndpointStateInitial
+ DatagramEndpointStateBound
+ DatagramEndpointStateConnected
+ DatagramEndpointStateClosed
+)
+
+// String implements fmt.Stringer.
+func (s DatagramEndpointState) String() string {
+ switch s {
+ case DatagramEndpointStateInitial:
+ return "INITIAL"
+ case DatagramEndpointStateBound:
+ return "BOUND"
+ case DatagramEndpointStateConnected:
+ return "CONNECTED"
+ case DatagramEndpointStateClosed:
+ return "CLOSED"
+ default:
+ panic(fmt.Sprintf("unhandled %[1]T variant = %[1]d", s))
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index f9a15efb2..00497bf07 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -329,6 +329,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
route = r
}
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
return 0, &tcpip.ErrBadBuffer{}
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
new file mode 100644
index 000000000..d10e3f13a
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "network",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/transport/udp:__pkg__",
+ ],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ ],
+)
+
+go_test(
+ name = "network_test",
+ size = "small",
+ srcs = ["endpoint_test.go"],
+ deps = [
+ ":network",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/udp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
new file mode 100644
index 000000000..0dce60d89
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -0,0 +1,722 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package network provides facilities to support tcpip.Endpoints that operate
+// at the network layer or above.
+package network
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Endpoint is a datagram-based endpoint. It only supports sending datagrams to
+// a peer.
+//
+// +stateify savable
+type Endpoint struct {
+ // The following fields must only be set once then never changed.
+ stack *stack.Stack `state:"manual"`
+ ops *tcpip.SocketOptions
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
+ // state holds a transport.DatagramBasedEndpointState.
+ //
+ // state must be read from/written to atomically.
+ state uint32
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ info stack.TransportEndpointInfo
+ // owner is the owner of transmitted packets.
+ owner tcpip.PacketOwner
+ writeShutdown bool
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ connectedRoute *stack.Route `state:"manual"`
+ multicastMemberships map[multicastMembership]struct{}
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ ttl uint8
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ multicastTTL uint8
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ multicastAddr tcpip.Address
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ multicastNICID tcpip.NICID
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ sendTOS uint8
+}
+
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+// Init initializes the endpoint.
+func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
+ if e.multicastMemberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
+ }
+
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
+ *e = Endpoint{
+ stack: s,
+ ops: ops,
+ netProto: netProto,
+ transProto: transProto,
+
+ state: uint32(transport.DatagramEndpointStateInitial),
+
+ info: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
+ effectiveNetProto: netProto,
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ }
+}
+
+// NetProto returns the network protocol the endpoint was initialized with.
+func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
+ return e.netProto
+}
+
+// setState sets the state of the endpoint.
+func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
+ atomic.StoreUint32(&e.state, uint32(state))
+}
+
+// State returns the state of the endpoint.
+func (e *Endpoint) State() transport.DatagramEndpointState {
+ return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
+}
+
+// Close cleans the endpoint's resources and leaves the endpoint in a closed
+// state.
+func (e *Endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return
+ }
+
+ for mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
+ if e.connectedRoute != nil {
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+ }
+
+ e.setEndpointState(transport.DatagramEndpointStateClosed)
+}
+
+// SetOwner sets the owner of transmitted packets.
+func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.owner = owner
+}
+
+func calculateTTL(route *stack.Route, ttl uint8, multicastTTL uint8) uint8 {
+ if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
+ return multicastTTL
+ }
+
+ if ttl == 0 {
+ return route.DefaultTTL()
+ }
+
+ return ttl
+}
+
+// WriteContext holds the context for a write.
+type WriteContext struct {
+ transProto tcpip.TransportProtocolNumber
+ route *stack.Route
+ ttl uint8
+ tos uint8
+ owner tcpip.PacketOwner
+}
+
+// Release releases held resources.
+func (c *WriteContext) Release() {
+ c.route.Release()
+ *c = WriteContext{}
+}
+
+// WritePacketInfo is the properties of a packet that may be written.
+type WritePacketInfo struct {
+ NetProto tcpip.NetworkProtocolNumber
+ LocalAddress, RemoteAddress tcpip.Address
+ MaxHeaderLength uint16
+ RequiresTXTransportChecksum bool
+}
+
+// PacketInfo returns the properties of a packet that will be written.
+func (c *WriteContext) PacketInfo() WritePacketInfo {
+ return WritePacketInfo{
+ NetProto: c.route.NetProto(),
+ LocalAddress: c.route.LocalAddress(),
+ RemoteAddress: c.route.RemoteAddress(),
+ MaxHeaderLength: c.route.MaxHeaderLength(),
+ RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(),
+ }
+}
+
+// WritePacket attempts to write the packet.
+func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
+ pkt.Owner = c.owner
+
+ if headerIncluded {
+ return c.route.WriteHeaderIncludedPacket(pkt)
+ }
+
+ return c.route.WritePacket(stack.NetworkHeaderParams{
+ Protocol: c.transProto,
+ TTL: c.ttl,
+ TOS: c.tos,
+ }, pkt)
+}
+
+// AcquireContextForWrite acquires a WriteContext.
+func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
+ }
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
+ }
+
+ if e.writeShutdown {
+ return WriteContext{}, &tcpip.ErrClosedForSend{}
+ }
+
+ route := e.connectedRoute
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return WriteContext{}, &tcpip.ErrDestinationRequired{}
+ }
+
+ route.Acquire()
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicID := opts.To.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
+ if e.info.BindNICID != 0 {
+ if nicID != 0 && nicID != e.info.BindNICID {
+ return WriteContext{}, &tcpip.ErrNoRoute{}
+ }
+
+ nicID = e.info.BindNICID
+ }
+
+ dst, netProto, err := e.checkV4MappedLocked(*opts.To)
+ if err != nil {
+ return WriteContext{}, err
+ }
+
+ route, _, err = e.connectRoute(nicID, dst, netProto)
+ if err != nil {
+ return WriteContext{}, err
+ }
+ }
+
+ if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
+ route.Release()
+ return WriteContext{}, &tcpip.ErrBroadcastDisabled{}
+ }
+
+ return WriteContext{
+ transProto: e.transProto,
+ route: route,
+ ttl: calculateTTL(route, e.ttl, e.multicastTTL),
+ tos: e.sendTOS,
+ owner: e.owner,
+ }, nil
+}
+
+// Disconnect disconnects the endpoint from its peer.
+func (e *Endpoint) Disconnect() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return
+ }
+
+ // Exclude ephemerally bound endpoints.
+ if e.info.BindNICID != 0 || e.info.ID.LocalAddress == "" {
+ e.info.ID = stack.TransportEndpointID{
+ LocalAddress: e.info.ID.LocalAddress,
+ }
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ } else {
+ e.info.ID = stack.TransportEndpointID{}
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
+ }
+
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+}
+
+// connectRoute establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+ localAddr := e.info.ID.LocalAddress
+ if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
+ }
+
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicID == 0 {
+ nicID = e.multicastNICID
+ }
+ if localAddr == "" && nicID == 0 {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
+ if err != nil {
+ return nil, 0, err
+ }
+ return r, nicID, nil
+}
+
+// Connect connects the endpoint to the address.
+func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ return nil
+ })
+}
+
+// ConnectAndThen connects the endpoint to the address and then calls the
+// provided function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the network protocol used to connect to the peer
+// and the source and destination addresses that will be used to send traffic to
+// the peer.
+func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicID := addr.NIC
+ switch e.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ if e.info.BindNICID == 0 {
+ break
+ }
+
+ if nicID != 0 && nicID != e.info.BindNICID {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ nicID = e.info.BindNICID
+ default:
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ r, nicID, err := e.connectRoute(nicID, addr, netProto)
+ if err != nil {
+ return err
+ }
+
+ id := stack.TransportEndpointID{
+ LocalAddress: e.info.ID.LocalAddress,
+ RemoteAddress: r.RemoteAddress(),
+ }
+ if e.State() == transport.DatagramEndpointStateInitial {
+ id.LocalAddress = r.LocalAddress()
+ }
+
+ if err := f(r.NetProto(), e.info.ID, id); err != nil {
+ return err
+ }
+
+ if e.connectedRoute != nil {
+ // If the endpoint was previously connected then release any previous route.
+ e.connectedRoute.Release()
+ }
+ e.connectedRoute = r
+ e.info.ID = id
+ e.info.RegisterNICID = nicID
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateConnected)
+ return nil
+}
+
+// Shutdown shutsdown the endpoint.
+func (e *Endpoint) Shutdown() tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ e.writeShutdown = true
+ return nil
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
+
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+ unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
+ }
+ return unwrapped, netProto, nil
+}
+
+func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
+}
+
+// Bind binds the endpoint to the address.
+func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
+ return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error {
+ return nil
+ })
+}
+
+// BindAndThen binds the endpoint to the address and then calls the provided
+// function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the bound network protocol and address.
+func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore.
+ if e.State() != transport.DatagramEndpointStateInitial {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ nicID := addr.NIC
+ if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
+ nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr)
+ if nicID == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if err := f(netProto, addr.Addr); err != nil {
+ return err
+ }
+
+ e.info.ID = stack.TransportEndpointID{
+ LocalAddress: addr.Addr,
+ }
+ e.info.BindNICID = nicID
+ e.info.RegisterNICID = nicID
+ e.info.BindAddr = addr.Addr
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ return nil
+}
+
+// GetLocalAddress returns the address that the endpoint is bound to.
+func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ addr := e.info.BindAddr
+ if e.State() == transport.DatagramEndpointStateConnected {
+ addr = e.connectedRoute.LocalAddress()
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.info.RegisterNICID,
+ Addr: addr,
+ }
+}
+
+// GetRemoteAddress returns the address that the endpoint is connected to.
+func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return tcpip.FullAddress{}, false
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.connectedRoute.RemoteAddress(),
+ NIC: e.info.RegisterNICID,
+ }, true
+}
+
+// SetSockOptInt sets the socket option.
+func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+ }
+
+ return nil
+}
+
+// GetSockOptInt returns the socket option.
+func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ v := int(e.multicastTTL)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ v := int(e.ttl)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ default:
+ return -1, &tcpip.ErrUnknownProtocolOption{}
+ }
+}
+
+// SetSockOpt sets the socket option.
+func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ fa, netProto, err := e.checkV4MappedLocked(fa)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if e.info.BindNICID != 0 && e.info.BindNICID != nic {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case *tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return &tcpip.ErrPortInUse{}
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToInsert] = struct{}{}
+
+ case *tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ delete(e.multicastMemberships, memToRemove)
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt returns the socket option.
+func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastInterfaceOption{
+ NIC: e.multicastNICID,
+ InterfaceAddr: e.multicastAddr,
+ }
+ e.mu.Unlock()
+
+ default:
+ return &tcpip.ErrUnknownProtocolOption{}
+ }
+ return nil
+}
+
+// Info returns a copy of the endpoint info.
+func (e *Endpoint) Info() stack.TransportEndpointInfo {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.info
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
new file mode 100644
index 000000000..858007156
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -0,0 +1,56 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *Endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.stack = s
+
+ for m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(fmt.Sprintf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err))
+ }
+ }
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound:
+ if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) {
+ if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 {
+ panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress))
+ }
+ }
+ case transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ multicastLoop := e.ops.GetMulticastLoop()
+ e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
+ if err != nil {
+ panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
new file mode 100644
index 000000000..2c43eb66a
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -0,0 +1,209 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+func TestEndpointStateTransitions(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ var (
+ ipv4NICAddr = testutil.MustParse4("1.2.3.4")
+ ipv6NICAddr = testutil.MustParse6("a::1")
+ ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
+ ipv6RemoteAddr = testutil.MustParse6("b::1")
+ )
+
+ data := buffer.View([]byte{1, 2, 4, 5})
+ v4Checker := func(t *testing.T, b buffer.View) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(ipv4NICAddr),
+ checker.DstAddr(ipv4RemoteAddr),
+ checker.IPPayload(data),
+ )
+ }
+
+ v6Checker := func(t *testing.T, b buffer.View) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(ipv6NICAddr),
+ checker.DstAddr(ipv6RemoteAddr),
+ checker.IPPayload(data),
+ )
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ expectedMaxHeaderLength uint16
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLocalAddr tcpip.Address
+ bindAddr tcpip.Address
+ expectedBoundAddr tcpip.Address
+ remoteAddr tcpip.Address
+ expectedRemoteAddr tcpip.Address
+ checker func(*testing.T, buffer.View)
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+ expectedNetProto: ipv4.ProtocolNumber,
+ expectedLocalAddr: ipv4NICAddr,
+ bindAddr: header.IPv4AllSystems,
+ expectedBoundAddr: header.IPv4AllSystems,
+ remoteAddr: ipv4RemoteAddr,
+ expectedRemoteAddr: ipv4RemoteAddr,
+ checker: v4Checker,
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv6FixedHeaderSize,
+ expectedNetProto: ipv6.ProtocolNumber,
+ expectedLocalAddr: ipv6NICAddr,
+ bindAddr: header.IPv6AllNodesMulticastAddress,
+ expectedBoundAddr: header.IPv6AllNodesMulticastAddress,
+ remoteAddr: ipv6RemoteAddr,
+ expectedRemoteAddr: ipv6RemoteAddr,
+ checker: v6Checker,
+ },
+ {
+ name: "IPv4-mapped-IPv6",
+ netProto: ipv6.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+ expectedNetProto: ipv4.ProtocolNumber,
+ expectedLocalAddr: ipv4NICAddr,
+ bindAddr: testutil.MustParse6("::ffff:e000:0001"),
+ expectedBoundAddr: header.IPv4AllSystems,
+ remoteAddr: testutil.MustParse6("::ffff:0607:0809"),
+ expectedRemoteAddr: ipv4RemoteAddr,
+ checker: v4Checker,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+ {Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+ })
+
+ var ops tcpip.SocketOptions
+ var ep network.Endpoint
+ ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ if state := ep.State(); state != transport.DatagramEndpointStateInitial {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
+ }
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+ if state := ep.State(); state != transport.DatagramEndpointStateBound {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound)
+ }
+ if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedBoundAddr}); diff != "" {
+ t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+ }
+ if addr, connected := ep.GetRemoteAddress(); connected {
+ t.Errorf("got ep.GetRemoteAddress() = (true, %#v), want = (false, _)", addr)
+ }
+
+ connectAddr := tcpip.FullAddress{Addr: test.remoteAddr}
+ if err := ep.Connect(connectAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", connectAddr, err)
+ }
+ if state := ep.State(); state != transport.DatagramEndpointStateConnected {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected)
+ }
+ if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedLocalAddr}); diff != "" {
+ t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+ }
+ if addr, connected := ep.GetRemoteAddress(); !connected {
+ t.Errorf("got ep.GetRemoteAddress() = (false, _), want = (true, %#v)", connectAddr)
+ } else if diff := cmp.Diff(addr, tcpip.FullAddress{Addr: test.expectedRemoteAddr}); diff != "" {
+ t.Errorf("remote address mismatch (-want +got):\n%s", diff)
+ }
+
+ ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("ep.AcquireContexForWrite({}): %s", err)
+ }
+ defer ctx.Release()
+ info := ctx.PacketInfo()
+ if diff := cmp.Diff(network.WritePacketInfo{
+ NetProto: test.expectedNetProto,
+ LocalAddress: test.expectedLocalAddr,
+ RemoteAddress: test.expectedRemoteAddr,
+ MaxHeaderLength: test.expectedMaxHeaderLength,
+ RequiresTXTransportChecksum: true,
+ }, info); diff != "" {
+ t.Errorf("write packet info mismatch (-want +got):\n%s", diff)
+ }
+ if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(info.MaxHeaderLength),
+ Data: data.ToVectorisedView(),
+ }), false /* headerIncluded */); err != nil {
+ t.Fatalf("ctx.WritePacket(_, false): %s", err)
+ }
+ if pkt, ok := e.Read(); !ok {
+ t.Fatalf("expected packet to be read from link endpoint")
+ } else {
+ test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ }
+
+ ep.Close()
+ if state := ep.State(); state != transport.DatagramEndpointStateClosed {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 8e7bb6c6e..89b4720aa 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -207,8 +207,52 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
return res, nil
}
-func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
- return 0, &tcpip.ErrInvalidOptionValue{}
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ if !ep.stack.PacketEndpointWriteSupported() {
+ return 0, &tcpip.ErrNotSupported{}
+ }
+
+ ep.mu.Lock()
+ closed := ep.closed
+ nicID := ep.boundNIC
+ ep.mu.Unlock()
+ if closed {
+ return 0, &tcpip.ErrClosedForSend{}
+ }
+
+ var remote tcpip.LinkAddress
+ proto := ep.netProto
+ if to := opts.To; to != nil {
+ remote = tcpip.LinkAddress(to.Addr)
+
+ if n := to.NIC; n != 0 {
+ nicID = n
+ }
+
+ if p := to.Port; p != 0 {
+ proto = tcpip.NetworkProtocolNumber(p)
+ }
+ }
+
+ if nicID == 0 {
+ return 0, &tcpip.ErrInvalidOptionValue{}
+ }
+
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+ payloadBytes := make(buffer.View, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
+ }
+
+ if err := func() tcpip.Error {
+ if ep.cooked {
+ return ep.stack.WritePacketToRemote(nicID, remote, proto, payloadBytes.ToVectorisedView())
+ }
+ return ep.stack.WriteRawPacket(nicID, proto, payloadBytes.ToVectorisedView())
+ }(); err != nil {
+ return 0, err
+ }
+ return int64(len(payloadBytes)), nil
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index e729921db..5c688d286 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -34,17 +34,11 @@ func (p *packet) loadReceivedAt(nsec int64) {
// saveData saves packet.data field.
func (p *packet) saveData() buffer.VectorisedView {
- // We cannot save p.data directly as p.data.views may alias to p.views,
- // which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
// loadData loads packet.data field.
func (p *packet) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
- // here because data.views is not guaranteed to be loaded by now. Plus,
- // data.views will be allocated anyway so there really is little point
- // of utilizing p.views for data.views.
p.data = data
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 55854ba59..3bf6c0a8f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -281,6 +281,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return nil, nil, nil, &tcpip.ErrInvalidEndpointState{}
}
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
payloadBytes := make([]byte, p.Len())
if _, err := io.ReadFull(p, payloadBytes); err != nil {
return nil, nil, nil, &tcpip.ErrBadBuffer{}
@@ -600,6 +601,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// We copy headers' underlying bytes because pkt.*Header may point to
// the middle of a slice, and another struct may point to the "outer"
// slice. Save/restore doesn't support overlapping slices and will fail.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
+ // overlapping slices.
var combinedVV buffer.VectorisedView
if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index c3922bbe5..5148fe157 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -68,6 +68,7 @@ go_library(
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
+ "//pkg/tcpip/internal/tcp",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index f8269efa6..03c9fafa1 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -606,14 +606,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
if opts.TS {
- // Create a barely-sufficient endpoint to calculate the TSVal.
- pseudoEndpoint := endpoint{
- TCPEndpointStateInner: stack.TCPEndpointStateInner{
- TSOffset: e.protocol.tsOffset(s.dstAddr, s.srcAddr),
- },
- stack: e.stack,
- }
- synOpts.TSVal = pseudoEndpoint.tsValNow()
+ offset := e.protocol.tsOffset(s.dstAddr, s.srcAddr)
+ now := e.stack.Clock().NowMonotonic()
+ synOpts.TSVal = offset.TSVal(now)
}
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
fields := tcpFields{
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 0623ee8ed..d2b8f298f 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2913,7 +2913,7 @@ func (e *endpoint) maybeEnableTimestamp(synOpts header.TCPSynOptions) {
}
func (e *endpoint) tsVal(now tcpip.MonotonicTime) uint32 {
- return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + e.TSOffset
+ return e.TSOffset.TSVal(now)
}
func (e *endpoint) tsValNow() uint32 {
@@ -2921,7 +2921,7 @@ func (e *endpoint) tsValNow() uint32 {
}
func (e *endpoint) elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
- return time.Duration(e.tsVal(now)-tsEcr) * time.Millisecond
+ return e.TSOffset.Elapsed(now, tsEcr)
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b0ffd2429..e4410ad93 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
+ "gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -158,7 +159,7 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
return stack.UnknownDestinationPacketHandled
}
-func (p *protocol) tsOffset(src, dst tcpip.Address) uint32 {
+func (p *protocol) tsOffset(src, dst tcpip.Address) tcp.TSOffset {
// Initialize a random tsOffset that will be added to the recentTS
// everytime the timestamp is sent when the Timestamp option is enabled.
//
@@ -173,7 +174,7 @@ func (p *protocol) tsOffset(src, dst tcpip.Address) uint32 {
// It never returns an error.
_, _ = h.Write([]byte(src))
_, _ = h.Write([]byte(dst))
- return h.Sum32()
+ return tcp.NewTSOffset(h.Sum32())
}
// replyWithReset replies to the given segment with a reset segment.
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 9ce8fcae9..90e493978 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -477,7 +477,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
// segments. This ensures that we always leave some space for the inorder
// segments to arrive allowing pending segments to be processed and
// delivered to the user.
- if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 {
+ if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && (r.PendingBufUsed+int(segLen)) < int(rcvBufSize)>>2 {
r.ep.rcvQueueInfo.rcvQueueMu.Lock()
r.PendingBufUsed += s.segMemSize()
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 90b74a2a7..bc8708a5b 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2128,6 +2128,211 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+func TestSmallReceiveBufferReadiness(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ })
+
+ ep := loopback.New()
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+
+ const nicID = 1
+ nicOpts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
+ }
+
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x7f\x00\x00\x01"),
+ PrefixLen: 8,
+ }
+ if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
+ if err != nil {
+ t.Fatalf("tcpip.NewSubnet failed: %s", err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: subnet,
+ NIC: nicID,
+ },
+ })
+ }
+
+ listenerEntry, listenerCh := waiter.NewChannelEntry(nil)
+ var listenerWQ waiter.Queue
+ listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer listener.Close()
+ listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents)
+ defer listenerWQ.EventUnregister(&listenerEntry)
+
+ if err := listener.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := listener.Listen(1); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ localAddress, err := listener.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress failed: %s", err)
+ }
+
+ for i := 8; i > 0; i /= 2 {
+ size := int64(i << 10)
+ t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) {
+ var clientWQ waiter.Queue
+ client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer client.Close()
+ switch err := client.Connect(localAddress).(type) {
+ case nil:
+ t.Fatal("Connect returned nil error")
+ case *tcpip.ErrConnectStarted:
+ default:
+ t.Fatalf("Connect failed: %s", err)
+ }
+
+ <-listenerCh
+ server, serverWQ, err := listener.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+ defer server.Close()
+
+ client.SocketOptions().SetReceiveBufferSize(size, true)
+ // Send buffer size doesn't seem to affect this test.
+ // server.SocketOptions().SetSendBufferSize(size, true)
+
+ clientEntry, clientCh := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents)
+ defer clientWQ.EventUnregister(&clientEntry)
+
+ serverEntry, serverCh := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverEntry, waiter.WritableEvents)
+ defer serverWQ.EventUnregister(&serverEntry)
+
+ var total int64
+ for {
+ var b [64 << 10]byte
+ var r bytes.Reader
+ r.Reset(b[:])
+ switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+ case nil:
+ t.Logf("wrote %d bytes", n)
+ total += n
+ continue
+ case *tcpip.ErrWouldBlock:
+ select {
+ case <-serverCh:
+ continue
+ case <-time.After(100 * time.Millisecond):
+ // Well and truly full.
+ t.Logf("send and receive queues are full")
+ }
+ default:
+ t.Fatalf("Write failed: %s", err)
+ }
+ break
+ }
+ t.Logf("wrote %d bytes in total", total)
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+
+ var b [64 << 10]byte
+ var r bytes.Reader
+ r.Reset(b[:])
+ if err := func() error {
+ var total int64
+ defer t.Logf("wrote %d bytes in total", total)
+ for r.Len() != 0 {
+ switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+ case nil:
+ t.Logf("wrote %d bytes", n)
+ total += n
+ case *tcpip.ErrWouldBlock:
+ for {
+ t.Logf("waiting on server")
+ select {
+ case <-serverCh:
+ case <-time.After(time.Second):
+ if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 {
+ t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness)
+ }
+ continue
+ }
+ break
+ }
+ default:
+ return fmt.Errorf("server.Write failed: %s", err)
+ }
+ }
+ if err := server.Shutdown(tcpip.ShutdownWrite); err != nil {
+ return fmt.Errorf("server.Shutdown failed: %s", err)
+ }
+ t.Logf("server end shutdown done")
+ return nil
+ }(); err != nil {
+ t.Error(err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+
+ if err := func() error {
+ total := 0
+ defer t.Logf("read %d bytes in total", total)
+ for {
+ switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
+ case nil:
+ t.Logf("read %d bytes", res.Count)
+ total += res.Count
+ t.Logf("read total %d bytes till now", total)
+ case *tcpip.ErrClosedForReceive:
+ return nil
+ case *tcpip.ErrWouldBlock:
+ for {
+ t.Logf("waiting on client")
+ select {
+ case <-clientCh:
+ case <-time.After(time.Second):
+ if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 {
+ return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness)
+ }
+ continue
+ }
+ break
+ }
+ default:
+ return fmt.Errorf("client.Write failed: %s", err)
+ }
+ }
+ }(); err != nil {
+ t.Error(err)
+ }
+ }()
+ })
+ }
+}
+
// Test the stack receive window advertisement on receiving segments smaller than
// segment overhead. It tests for the right edge of the window to not grow when
// the endpoint is not being read from.
diff --git a/pkg/tcpip/transport/transport.go b/pkg/tcpip/transport/transport.go
new file mode 100644
index 000000000..4c2ae87f4
--- /dev/null
+++ b/pkg/tcpip/transport/transport.go
@@ -0,0 +1,16 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport supports transport protocols.
+package transport
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index cdc344ab7..5cc7a2886 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -35,6 +35,8 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 108580508..4b6bdc3be 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,8 +15,8 @@
package udp
import (
+ "fmt"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -40,36 +42,6 @@ type udpPacket struct {
tos uint8
}
-// EndpointState represents the state of a UDP endpoint.
-type EndpointState tcpip.EndpointState
-
-// Endpoint states. Note that are represented in a netstack-specific manner and
-// may not be meaningful externally. Specifically, they need to be translated to
-// Linux's representation for these states if presented to userspace.
-const (
- _ EndpointState = iota
- StateInitial
- StateBound
- StateConnected
- StateClosed
-)
-
-// String implements fmt.Stringer.
-func (s EndpointState) String() string {
- switch s {
- case StateInitial:
- return "INITIAL"
- case StateBound:
- return "BOUND"
- case StateConnected:
- return "CONNECTING"
- case StateClosed:
- return "CLOSED"
- default:
- return "UNKNOWN"
- }
-}
-
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -79,7 +51,6 @@ func (s EndpointState) String() string {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
@@ -87,6 +58,10 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -96,37 +71,19 @@ type endpoint struct {
rcvBufSize int
rcvClosed bool
- // The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- // state must be read/set using the EndpointState()/setEndpointState()
- // methods.
- state uint32
- route *stack.Route `state:"manual"`
- dstPort uint16
- ttl uint8
- multicastTTL uint8
- multicastAddr tcpip.Address
- multicastNICID tcpip.NICID
- portFlags ports.Flags
-
lastErrorMu sync.Mutex `state:"nosave"`
lastError tcpip.Error
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
- // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
- // applied while sending packets. Defaults to 0 as on Linux.
- sendTOS uint8
-
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
-
- // multicastMemberships that need to be remvoed when the endpoint is
- // closed. Protected by the mu mutex.
- multicastMemberships map[multicastMembership]struct{}
+ readShutdown bool
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -136,55 +93,25 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
-}
-// +stateify savable
-type multicastMembership struct {
- nicID tcpip.NICID
- multicastAddr tcpip.Address
+ localPort uint16
+ remotePort uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: header.UDPProtocolNumber,
- },
+ stack: s,
waiterQueue: waiterQueue,
- // RFC 1075 section 5.4 recommends a TTL of 1 for membership
- // requests.
- //
- // RFC 5135 4.2.1 appears to assume that IGMP messages have a
- // TTL of 1.
- //
- // RFC 5135 Appendix A defines TTL=1: A multicast source that
- // wants its traffic to not traverse a router (e.g., leave a
- // home network) may find it useful to send traffic with IP
- // TTL=1.
- //
- // Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastMemberships: make(map[multicastMembership]struct{}),
- state: uint32(StateInitial),
- uniqueID: s.UniqueID(),
+ uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -200,20 +127,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
-// setEndpointState updates the state of the endpoint to state atomically. This
-// method is unexported as the only place we should update the state is in this
-// package but we allow the state to be read freely without holding e.mu.
-//
-// Precondition: e.mu must be held to call this method.
-func (e *endpoint) setEndpointState(state EndpointState) {
- atomic.StoreUint32(&e.state, uint32(state))
-}
-
-// EndpointState() returns the current state of the endpoint.
-func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32(&e.state))
-}
-
// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -244,16 +157,22 @@ func (e *endpoint) Abort() {
// associated with it.
func (e *endpoint) Close() {
e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.EndpointState() {
- case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ e.mu.Unlock()
+ return
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice)
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -261,13 +180,10 @@ func (e *endpoint) Close() {
e.stack.ReleasePort(portRes)
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- for mem := range e.multicastMemberships {
- e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
- }
- e.multicastMemberships = nil
-
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -278,14 +194,9 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- // Update the state.
- e.setEndpointState(StateClosed)
-
+ e.net.Shutdown()
+ e.net.Close()
+ e.readShutdown = true
e.mu.Unlock()
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
@@ -359,19 +270,19 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
return res, nil
}
-// prepareForWrite prepares the endpoint for sending data. In particular, it
-// binds it if it's still in the initial state. To do so, it must first
+// prepareForWriteInner prepares the endpoint for sending data. In particular,
+// it binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.EndpointState() {
- case StateInitial:
- case StateConnected:
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
- case StateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -386,7 +297,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -398,33 +309,6 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
return true, nil
}
-// connectRoute establishes a route to the specified interface or the
-// configured multicast interface if no interface is specified and the
-// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
- localAddr := e.ID.LocalAddress
- if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
- // A packet can only originate from a unicast address (i.e., an interface).
- localAddr = ""
- }
-
- if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
- if nicID == 0 {
- nicID = e.multicastNICID
- }
- if localAddr == "" && nicID == 0 {
- localAddr = e.multicastAddr
- }
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
- if err != nil {
- return nil, 0, err
- }
- return r, nicID, nil
-}
-
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
@@ -448,18 +332,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
+func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return udpPacketInfo{}, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(opts.To)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
return udpPacketInfo{}, err
}
@@ -469,49 +348,28 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions
}
}
- route := e.route
- dstPort := e.dstPort
+ dst, connected := e.net.GetRemoteAddress()
+ dst.Port = e.remotePort
if opts.To != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := opts.To.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return udpPacketInfo{}, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
if opts.To.Port == 0 {
// Port 0 is an invalid port to send to.
return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
}
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
- if err != nil {
- return udpPacketInfo{}, err
- }
-
- r, _, err := e.connectRoute(nicID, dst, netProto)
- if err != nil {
- return udpPacketInfo{}, err
- }
- defer r.Release()
-
- route = r
- dstPort = dst.Port
+ dst = *opts.To
+ } else if !connected {
+ return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
}
- if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{}
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ if err != nil {
+ return udpPacketInfo{}, err
}
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
+ ctx.Release()
return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
@@ -520,50 +378,25 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions
if so.GetRecvError() {
so.QueueLocalErr(
&tcpip.ErrMessageTooLong{},
- route.NetProto(),
+ e.net.NetProto(),
header.UDPMaximumPacketSize,
- tcpip.FullAddress{
- NIC: route.NICID(),
- Addr: route.RemoteAddress(),
- Port: dstPort,
- },
+ dst,
v,
)
}
+ ctx.Release()
return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}
- ttl := e.ttl
- useDefaultTTL := ttl == 0
- if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
- ttl = e.multicastTTL
- // Multicast allows a 0 TTL.
- useDefaultTTL = false
- }
-
return udpPacketInfo{
- route: route,
- data: buffer.View(v),
- localPort: e.ID.LocalPort,
- remotePort: dstPort,
- ttl: ttl,
- useDefaultTTL: useDefaultTTL,
- tos: e.sendTOS,
- owner: e.owner,
- noChecksum: e.SocketOptions().GetNoChecksum(),
+ ctx: ctx,
+ data: v,
+ localPort: e.localPort,
+ remotePort: dst.Port,
}, nil
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if err := e.LastError(); err != nil {
- return 0, err
- }
-
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
// Do not hold lock when sending as loopback is synchronous and if the UDP
// datagram ends up generating an ICMP response then it can result in a
// deadlock where the ICMP response handling ends up acquiring this endpoint's
@@ -574,15 +407,53 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
//
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
- u, err := e.buildUDPPacketInfo(p, opts)
- if err != nil {
+
+ if err := e.LastError(); err != nil {
return 0, err
}
- n, err := u.send()
+
+ udpInfo, err := e.prepareForWrite(p, opts)
if err != nil {
return 0, err
}
- return int64(n), nil
+ defer udpInfo.ctx.Release()
+
+ pktInfo := udpInfo.ctx.PacketInfo()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength),
+ Data: udpInfo.data.ToVectorisedView(),
+ })
+
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
+
+ length := uint16(pkt.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: udpInfo.localPort,
+ DstPort: udpInfo.remotePort,
+ Length: length,
+ })
+
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if pktInfo.RequiresTXTransportChecksum &&
+ (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
+ udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine(
+ header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
+ pkt.Data().AsRange().Checksum(),
+ )))
+ }
+ if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ e.stack.Stats().UDP.PacketSendErrors.Increment()
+ return 0, err
+ }
+
+ // Track count of packets sent.
+ e.stack.Stats().UDP.PacketsSent.Increment()
+ return int64(len(udpInfo.data)), nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
@@ -601,36 +472,7 @@ func (e *endpoint) OnReusePortSet(v bool) {
// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.MTUDiscoverOption:
- // Return not supported if the value is not disabling path
- // MTU discovery.
- if v != tcpip.PMTUDiscoveryDont {
- return &tcpip.ErrNotSupported{}
- }
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- e.multicastTTL = uint8(v)
- e.mu.Unlock()
-
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- }
-
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
@@ -642,145 +484,12 @@ func (e *endpoint) HasNIC(id int32) bool {
// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
-
- fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
- if err != nil {
- return err
- }
- nic := v.NIC
- addr := fa.Addr
-
- if nic == 0 && addr == "" {
- e.multicastAddr = ""
- e.multicastNICID = 0
- break
- }
-
- if nic != 0 {
- if !e.stack.CheckNIC(nic) {
- return &tcpip.ErrBadLocalAddress{}
- }
- } else {
- nic = e.stack.CheckLocalAddress(0, netProto, addr)
- if nic == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
- }
-
- if e.BindNICID != 0 && e.BindNICID != nic {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- e.multicastNICID = nic
- e.multicastAddr = addr
-
- case *tcpip.AddMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
-
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToInsert]; ok {
- return &tcpip.ErrPortInUse{}
- }
-
- if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- e.multicastMemberships[memToInsert] = struct{}{}
-
- case *tcpip.RemoveMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToRemove]; !ok {
- return &tcpip.ErrBadLocalAddress{}
- }
-
- if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- delete(e.multicastMemberships, memToRemove)
-
- case *tcpip.SocketDetachFilterOption:
- return nil
- }
- return nil
+ return e.net.SetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
- case tcpip.IPv4TOSOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.MTUDiscoverOption:
- // The only supported setting is path MTU discovery disabled.
- return tcpip.PMTUDiscoveryDont, nil
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- v := int(e.multicastTTL)
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveQueueSizeOption:
v := 0
e.rcvMu.Lock()
@@ -791,108 +500,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.mu.Lock()
- v := int(e.ttl)
- e.mu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- *o = tcpip.MulticastInterfaceOption{
- NIC: e.multicastNICID,
- InterfaceAddr: e.multicastAddr,
- }
- e.mu.Unlock()
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
- return nil
+ return e.net.GetSockOpt(opt)
}
-// udpPacketInfo contains all information required to send a UDP packet.
-//
-// This should be used as a value-only type, which exists in order to simplify
-// return value syntax. It should not be exported or extended.
+// udpPacketInfo holds information needed to send a UDP packet.
type udpPacketInfo struct {
- route *stack.Route
- data buffer.View
- localPort uint16
- remotePort uint16
- ttl uint8
- useDefaultTTL bool
- tos uint8
- owner tcpip.PacketOwner
- noChecksum bool
-}
-
-// send sends the given packet.
-func (u *udpPacketInfo) send() (int, tcpip.Error) {
- vv := u.data.ToVectorisedView()
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()),
- Data: vv,
- })
- pkt.Owner = u.owner
-
- // Initialize the UDP header.
- udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- pkt.TransportProtocolNumber = ProtocolNumber
-
- length := uint16(pkt.Size())
- udp.Encode(&header.UDPFields{
- SrcPort: u.localPort,
- DstPort: u.remotePort,
- Length: length,
- })
-
- // Set the checksum field unless TX checksum offload is enabled.
- // On IPv4, UDP checksum is optional, and a zero value indicates the
- // transmitter skipped the checksum generation (RFC768).
- // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if u.route.RequiresTXTransportChecksum() &&
- (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) {
- xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length)
- for _, v := range vv.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udp.SetChecksum(^udp.CalculateChecksum(xsum))
- }
-
- if u.useDefaultTTL {
- u.ttl = u.route.DefaultTTL()
- }
- if err := u.route.WritePacket(stack.NetworkHeaderParams{
- Protocol: ProtocolNumber,
- TTL: u.ttl,
- TOS: u.tos,
- }, pkt); err != nil {
- u.route.Stats().UDP.PacketSendErrors.Increment()
- return 0, err
- }
-
- // Track count of packets sent.
- u.route.Stats().UDP.PacketsSent.Increment()
- return len(u.data), nil
-}
-
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
+ ctx network.WriteContext
+ data buffer.View
+ localPort uint16
+ remotePort uint16
}
// Disconnect implements tcpip.Endpoint.
@@ -900,7 +523,7 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.EndpointState() != StateConnected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return nil
}
var (
@@ -913,26 +536,28 @@ func (e *endpoint) Disconnect() tcpip.Error {
boundPortFlags := e.boundPortFlags
// Exclude ephemerally bound endpoints.
- if e.BindNICID != 0 || e.ID.LocalAddress == "" {
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ if info.BindNICID != 0 || info.ID.LocalAddress == "" {
var err tcpip.Error
id = stack.TransportEndpointID{
- LocalPort: e.ID.LocalPort,
- LocalAddress: e.ID.LocalAddress,
+ LocalPort: info.ID.LocalPort,
+ LocalAddress: info.ID.LocalAddress,
}
id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
return err
}
- e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
- if e.ID.LocalPort != 0 {
+ if info.ID.LocalPort != 0 {
// Release the ephemeral port.
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: info.ID.LocalAddress,
+ Port: info.ID.LocalPort,
Flags: boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -940,15 +565,14 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
- e.setEndpointState(StateInitial)
}
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
- e.ID = id
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
e.boundBindToDevice = btd
- e.route.Release()
- e.route = nil
- e.dstPort = 0
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+
+ e.net.Disconnect()
return nil
}
@@ -958,88 +582,48 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- var localPort uint16
- switch e.EndpointState() {
- case StateInitial:
- case StateBound, StateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
-
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.localPort
+ nextID.RemotePort = addr.Port
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: e.ID.LocalAddress,
- LocalPort: localPort,
- RemotePort: addr.Port,
- RemoteAddress: r.RemoteAddress(),
- }
+ oldPortFlags := e.boundPortFlags
- if e.EndpointState() == StateInitial {
- id.LocalAddress = r.LocalAddress()
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv4ProtocolNumber,
- header.IPv6ProtocolNumber,
+ nextID, btd, err := e.registerWithStack(netProtos, nextID)
+ if err != nil {
+ return err
}
- }
- oldPortFlags := e.boundPortFlags
+ // Remove the old registration.
+ if e.localPort != 0 {
+ previousID.LocalPort = e.localPort
+ previousID.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
+ }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = nextID.LocalPort
+ e.remotePort = nextID.RemotePort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- // Remove the old registration.
- if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
- }
-
- e.ID = id
- e.boundBindToDevice = btd
- if e.route != nil {
- // If the endpoint was already connected then make sure we release the
- // previous route.
- e.route.Release()
- }
- e.route = r
- e.dstPort = addr.Port
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- e.setEndpointState(StateConnected)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
@@ -1054,15 +638,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // A socket in the bound state can still receive multicast messages,
- // so we need to notify waiters on shutdown.
- if state := e.EndpointState(); state != StateBound && state != StateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- e.shutdownFlags |= flags
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
+ }
if flags&tcpip.ShutdownRead != 0 {
+ e.readShutdown = true
+
e.rcvMu.Lock()
wasClosed := e.rcvClosed
e.rcvClosed = true
@@ -1088,7 +680,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- if e.ID.LocalPort == 0 {
+ if e.localPort == 0 {
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
@@ -1126,56 +718,43 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv6ProtocolNumber,
- header.IPv4ProtocolNumber,
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{boundNetProto}
+ if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
}
- }
- nicID := addr.NIC
- if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
- // A local unicast address was specified, verify that it's valid.
- nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nicID == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: boundAddr,
+ }
+ id, btd, err := e.registerWithStack(netProtos, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = id.LocalPort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.boundBindToDevice = btd
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- // Mark endpoint as bound.
- e.setEndpointState(StateBound)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
@@ -1190,9 +769,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
return err
}
- // Save the effective NICID generated by bindLocked.
- e.BindNICID = e.RegisterNICID
-
return nil
}
@@ -1201,16 +777,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- addr := e.ID.LocalAddress
- if e.EndpointState() == StateConnected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.localPort
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -1218,15 +787,13 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.EndpointState() != StateConnected || e.dstPort == 0 {
+ addr, connected := e.net.GetRemoteAddress()
+ if !connected || e.remotePort == 0 {
return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ addr.Port = e.remotePort
+ return addr, nil
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -1376,19 +943,20 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
payload = udp.Payload()
}
+ id := e.net.Info().ID
e.SocketOptions().QueueErr(&tcpip.SockError{
Err: err,
Cause: transErr,
Payload: payload,
Dst: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: id.RemoteAddress,
+ Port: e.remotePort,
},
Offender: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: e.localPort,
},
NetProto: pkt.NetworkProtocolNumber,
})
@@ -1403,7 +971,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// TODO(gvisor.dev/issues/5270): Handle all transport errors.
switch transErr.Kind() {
case stack.DestinationPortUnreachableTransportError:
- if e.EndpointState() == StateConnected {
+ if e.net.State() == transport.DatagramEndpointStateConnected {
e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
}
}
@@ -1411,16 +979,17 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
- return uint32(e.EndpointState())
+ return uint32(e.net.State())
}
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
- return &ret
+ defer e.mu.RUnlock()
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ return &info
}
// Stats returns a pointer to the endpoint stats.
@@ -1431,13 +1000,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements tcpip.Endpoint.
func (*endpoint) Wait() {}
-func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
-}
-
// SetOwner implements tcpip.Endpoint.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// SocketOptions implements tcpip.Endpoint.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 1f638c3f6..2ff8b0482 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,12 +15,13 @@
package udp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -35,17 +36,11 @@ func (p *udpPacket) loadReceivedAt(nsec int64) {
// saveData saves udpPacket.data field.
func (p *udpPacket) saveData() buffer.VectorisedView {
- // We cannot save p.data directly as p.data.views may alias to p.views,
- // which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
// loadData loads udpPacket.data field.
func (p *udpPacket) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
- // here because data.views is not guaranteed to be loaded by now. Plus,
- // data.views will be allocated anyway so there really is little point
- // of utilizing p.views for data.views.
p.data = data
}
@@ -66,50 +61,28 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.mu.Lock()
defer e.mu.Unlock()
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- for m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
-
- state := e.EndpointState()
- if state != StateBound && state != StateConnected {
- return
- }
-
- netProto := e.effectiveNetProtos[0]
- // Connect() and bindLocked() both assert
- //
- // netProto == header.IPv6ProtocolNumber
- //
- // before creating a multi-entry effectiveNetProtos.
- if len(e.effectiveNetProtos) > 1 {
- netProto = header.IPv6ProtocolNumber
- }
-
- var err tcpip.Error
- if state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ var err tcpip.Error
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
- } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
- // A local unicast address is specified, verify that it's valid.
- if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- // Our saved state had a port, but we don't actually have a
- // reservation. We need to remove the port from our state, but still
- // pass it to the reservation machinery.
- id := e.ID
- e.ID.LocalPort = 0
- e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
- if err != nil {
- panic(err)
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 7c357cb09..7238fc019 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -70,28 +70,29 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
netHdr := r.pkt.Network()
- route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
- if err != nil {
+ if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil {
+ return nil, err
+ }
+
+ if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil {
return nil, err
}
- ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
- route.Release()
return nil, err
}
- ep.ID = r.id
- ep.route = route
- ep.dstPort = r.id.RemotePort
+ ep.localPort = r.id.LocalPort
+ ep.remotePort = r.id.RemotePort
ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
- ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
- ep.state = uint32(StateConnected)
-
ep.rcvMu.Lock()
ep.rcvReady = true
ep.rcvMu.Unlock()
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 3f667cd74..1dd0048ac 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -1089,13 +1089,14 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st
return inet.NewRootNamespace(hostinet.NewStack(), nil), nil
case config.NetworkNone, config.NetworkSandbox:
- s, err := newEmptySandboxNetworkStack(clock, uniqueID)
+ s, err := newEmptySandboxNetworkStack(clock, uniqueID, conf.AllowPacketEndpointWrite)
if err != nil {
return nil, err
}
creator := &sandboxNetstackCreator{
- clock: clock,
- uniqueID: uniqueID,
+ clock: clock,
+ uniqueID: uniqueID,
+ allowPacketEndpointWrite: conf.AllowPacketEndpointWrite,
}
return inet.NewRootNamespace(s, creator), nil
@@ -1105,7 +1106,7 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st
}
-func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID, allowPacketEndpointWrite bool) (inet.Stack, error) {
netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}
transProtos := []stack.TransportProtocolFactory{
tcp.NewProtocol,
@@ -1121,9 +1122,10 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in
HandleLocal: true,
// Enable raw sockets for users with sufficient
// privileges.
- RawFactory: raw.EndpointFactory{},
- UniqueID: uniqueID,
- DefaultIPTables: netfilter.DefaultLinuxTables,
+ RawFactory: raw.EndpointFactory{},
+ AllowPacketEndpointWrite: allowPacketEndpointWrite,
+ UniqueID: uniqueID,
+ DefaultIPTables: netfilter.DefaultLinuxTables,
})}
// Enable SACK Recovery.
@@ -1160,13 +1162,14 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in
//
// +stateify savable
type sandboxNetstackCreator struct {
- clock tcpip.Clock
- uniqueID stack.UniqueID
+ clock tcpip.Clock
+ uniqueID stack.UniqueID
+ allowPacketEndpointWrite bool
}
// CreateStack implements kernel.NetworkStackCreator.CreateStack.
func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) {
- s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID)
+ s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID, f.allowPacketEndpointWrite)
if err != nil {
return nil, err
}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index 2f52863ff..2ce8cc006 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -86,6 +86,9 @@ type Config struct {
// capabilities.
EnableRaw bool `flag:"net-raw"`
+ // AllowPacketEndpointWrite enables write operations on packet endpoints.
+ AllowPacketEndpointWrite bool `flag:"TESTONLY-allow-packet-endpoint-write"`
+
// HardwareGSO indicates that hardware segmentation offload is enabled.
HardwareGSO bool `flag:"gso"`
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index 85507902a..cc5aba474 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -92,6 +92,7 @@ func RegisterFlags() {
// Test flags, not to be used outside tests, ever.
flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.")
flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.")
+ flag.Bool("TESTONLY-allow-packet-endpoint-write", false, "TEST ONLY; do not ever use! Used for tests to allow writes on packet sockets.")
})
}
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
index ea83bbe72..49f41c887 100644
--- a/test/packetimpact/dut/posix_server.cc
+++ b/test/packetimpact/dut/posix_server.cc
@@ -28,10 +28,10 @@
#include <iostream>
#include <unordered_map>
+#include "absl/strings/str_format.h"
#include "include/grpcpp/security/server_credentials.h"
#include "include/grpcpp/server_builder.h"
#include "include/grpcpp/server_context.h"
-#include "absl/strings/str_format.h"
#include "test/packetimpact/proto/posix_server.grpc.pb.h"
#include "test/packetimpact/proto/posix_server.pb.h"
diff --git a/test/runner/main.go b/test/runner/main.go
index 34e9c6279..2cab8d2d4 100644
--- a/test/runner/main.go
+++ b/test/runner/main.go
@@ -170,6 +170,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
"-network", *network,
"-log-format=text",
"-TESTONLY-unsafe-nonroot=true",
+ "-TESTONLY-allow-packet-endpoint-write=true",
"-net-raw=true",
fmt.Sprintf("-panic-signal=%d", unix.SIGTERM),
"-watchdog-action=panic",
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 01ee432cb..b06b3d233 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -3293,9 +3293,12 @@ cc_library(
],
deps = [
":unix_domain_socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:memory_util",
"//test/util:socket_util",
"@com_google_absl//absl/strings",
gtest,
+ "//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
],
diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc
index 8202d35fa..a2cc59e83 100644
--- a/test/syscalls/linux/eventfd.cc
+++ b/test/syscalls/linux/eventfd.cc
@@ -149,31 +149,6 @@ TEST(EventfdTest, BigWriteBigRead) {
EXPECT_EQ(l[0], 1);
}
-TEST(EventfdTest, SpliceFromPipePartialSucceeds) {
- int pipes[2];
- ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
- const FileDescriptor pipe_rfd(pipes[0]);
- const FileDescriptor pipe_wfd(pipes[1]);
- constexpr uint64_t kVal{1};
-
- FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK));
-
- uint64_t event_array[2];
- event_array[0] = kVal;
- event_array[1] = kVal;
- ASSERT_THAT(write(pipe_wfd.get(), event_array, sizeof(event_array)),
- SyscallSucceedsWithValue(sizeof(event_array)));
- EXPECT_THAT(splice(pipe_rfd.get(), /*__offin=*/nullptr, efd.get(),
- /*__offout=*/nullptr, sizeof(event_array[0]) + 1,
- SPLICE_F_NONBLOCK),
- SyscallSucceedsWithValue(sizeof(event_array[0])));
-
- uint64_t val;
- ASSERT_THAT(read(efd.get(), &val, sizeof(val)),
- SyscallSucceedsWithValue(sizeof(val)));
- EXPECT_EQ(val, kVal);
-}
-
// NotifyNonZero is inherently racy, so random save is disabled.
TEST(EventfdTest, NotifyNonZero) {
// Waits will time out at 10 seconds.
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index f6b78989b..e2622232d 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -1849,34 +1849,6 @@ TEST(Inotify, SpliceOnWatchTarget) {
}));
}
-TEST(Inotify, SpliceOnInotifyFD) {
- int pipefds[2];
- ASSERT_THAT(pipe2(pipefds, O_NONBLOCK), SyscallSucceeds());
-
- const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- const FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
- const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
- root.path(), "some content", TempPath::kDefaultFileMode));
-
- const FileDescriptor file1_fd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
- const int watcher = ASSERT_NO_ERRNO_AND_VALUE(
- InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
-
- char buf;
- EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
-
- EXPECT_THAT(splice(fd.get(), nullptr, pipefds[1], nullptr,
- sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK),
- SyscallSucceedsWithValue(sizeof(struct inotify_event)));
-
- const FileDescriptor read_fd(pipefds[0]);
- const std::vector<Event> events =
- ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get()));
- ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)}));
-}
-
// Watches on a parent should not be triggered by actions on a hard link to one
// of its children that has a different parent.
TEST(Inotify, LinkOnOtherParent) {
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index ca4ab0aad..bfa5d179a 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -285,14 +285,6 @@ TEST_P(CookedPacketTest, Send) {
memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr));
memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage));
- // We don't implement writing to packet sockets on gVisor.
- if (IsRunningOnGvisor()) {
- ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0,
- reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
- SyscallFailsWithErrno(EINVAL));
- GTEST_SKIP();
- }
-
// Send it.
ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0,
reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 61714d1da..e57c60ffa 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -296,14 +296,6 @@ TEST_P(RawPacketTest, Send) {
memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr) + sizeof(udphdr), kMessage,
sizeof(kMessage));
- // We don't implement writing to packet sockets on gVisor.
- if (IsRunningOnGvisor()) {
- ASSERT_THAT(sendto(s_, send_buf, sizeof(send_buf), 0,
- reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
- SyscallFailsWithErrno(EINVAL));
- GTEST_SKIP();
- }
-
// Send it.
ASSERT_THAT(sendto(s_, send_buf, sizeof(send_buf), 0,
reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index bea4ee71c..9bd3bd5e8 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -208,38 +208,6 @@ TEST(SendFileTest, SendAndUpdateFileOffset) {
absl::string_view(actual, kHalfDataSize));
}
-TEST(SendFileTest, SendToDevZeroAndUpdateFileOffset) {
- // Create temp files.
- // Test input string length must be > 2 AND even.
- constexpr char kData[] = "The slings and arrows of outrageous fortune,";
- constexpr int kDataSize = sizeof(kData) - 1;
- constexpr int kHalfDataSize = kDataSize / 2;
- const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
- GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
-
- // Open the input file as read only.
- const FileDescriptor inf =
- ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
-
- // Open /dev/zero as write only.
- const FileDescriptor outf =
- ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY));
-
- // Send data and verify that sendfile returns the correct value.
- int bytes_sent;
- EXPECT_THAT(
- bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize),
- SyscallSucceedsWithValue(kHalfDataSize));
-
- char actual[kHalfDataSize];
- // Verify that the input file offset has been updated.
- ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent),
- SyscallSucceedsWithValue(kHalfDataSize));
- EXPECT_EQ(
- absl::string_view(kData + kDataSize - bytes_sent, kDataSize - bytes_sent),
- absl::string_view(actual, kHalfDataSize));
-}
-
TEST(SendFileTest, SendAndUpdateFileOffsetFromNonzeroStartingPoint) {
// Create temp files.
// Test input string length must be > 2 AND divisible by 4.
@@ -609,23 +577,6 @@ TEST(SendFileTest, SendPipeBlocks) {
SyscallSucceedsWithValue(kDataSize));
}
-TEST(SendFileTest, SendToSpecialFile) {
- // Create temp file.
- const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
- GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
-
- const FileDescriptor inf =
- ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
- constexpr int kSize = 0x7ff;
- ASSERT_THAT(ftruncate(inf.get(), kSize), SyscallSucceeds());
-
- auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
-
- // eventfd can accept a number of bytes which is a multiple of 8.
- EXPECT_THAT(sendfile(eventfd.get(), inf.get(), nullptr, 0xfffff),
- SyscallSucceedsWithValue(kSize & (~7)));
-}
-
TEST(SendFileTest, SendFileToPipe) {
// Create temp file.
constexpr char kData[] = "<insert-quote-here>";
@@ -672,57 +623,6 @@ TEST(SendFileTest, SendFileToSelf) {
SyscallSucceedsWithValue(kSendfileSize));
}
-static volatile int signaled = 0;
-void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; }
-
-TEST(SendFileTest, ToEventFDDoesNotSpin) {
- FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
-
- // Write the maximum value of an eventfd to a file.
- const uint64_t kMaxEventfdValue = 0xfffffffffffffffe;
- const auto tempfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const auto tempfd = ASSERT_NO_ERRNO_AND_VALUE(Open(tempfile.path(), O_RDWR));
- ASSERT_THAT(
- pwrite(tempfd.get(), &kMaxEventfdValue, sizeof(kMaxEventfdValue), 0),
- SyscallSucceedsWithValue(sizeof(kMaxEventfdValue)));
-
- // Set the eventfd's value to 1.
- const uint64_t kOne = 1;
- ASSERT_THAT(write(efd.get(), &kOne, sizeof(kOne)),
- SyscallSucceedsWithValue(sizeof(kOne)));
-
- // Set up signal handler.
- struct sigaction sa = {};
- sa.sa_sigaction = SigUsr1Handler;
- sa.sa_flags = SA_SIGINFO;
- const auto cleanup_sigact =
- ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa));
-
- // Send SIGUSR1 to this thread in 1 second.
- struct sigevent sev = {};
- sev.sigev_notify = SIGEV_THREAD_ID;
- sev.sigev_signo = SIGUSR1;
- sev.sigev_notify_thread_id = gettid();
- auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
- struct itimerspec its = {};
- its.it_value = absl::ToTimespec(absl::Seconds(1));
- DisableSave ds; // Asserting an EINTR.
- ASSERT_NO_ERRNO(timer.Set(0, its));
-
- // Sendfile from tempfd to the eventfd. Since the eventfd is not already at
- // its maximum value, the eventfd is "ready for writing"; however, since the
- // eventfd's existing value plus the new value would exceed the maximum, the
- // write should internally fail with EWOULDBLOCK. In this case, sendfile()
- // should block instead of spinning, and eventually be interrupted by our
- // timer. See b/172075629.
- EXPECT_THAT(
- sendfile(efd.get(), tempfd.get(), nullptr, sizeof(kMaxEventfdValue)),
- SyscallFailsWithErrno(EINTR));
-
- // Signal should have been handled.
- EXPECT_EQ(signaled, 1);
-}
-
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc
index cf96b2075..43433eaae 100644
--- a/test/syscalls/linux/socket_unix.cc
+++ b/test/syscalls/linux/socket_unix.cc
@@ -27,7 +27,10 @@
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/memory_util.h"
#include "test/util/socket_util.h"
+#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -268,6 +271,18 @@ TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) {
}
}
+// Repro for b/196804997.
+TEST_P(UnixSocketPairTest, SendFromMmapBeyondEof) {
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(nullptr, kPageSize, PROT_READ, MAP_SHARED, fd.get(), 0));
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(send(sockets->first_fd(), m.ptr(), m.len(), 0),
+ SyscallFailsWithErrno(EFAULT));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index 6e9f70f8c..2f3cfc3f3 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -181,6 +181,21 @@ TEST_P(StreamUnixSocketPairTest, SetSocketSendBuf) {
ASSERT_EQ(quarter_sz, val);
}
+TEST_P(StreamUnixSocketPairTest, SendBufferOverflow) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ auto s = sockets->first_fd();
+
+ constexpr int kBufSz = 4096;
+ std::vector<char> buf(kBufSz * 4);
+ ASSERT_THAT(RetryEINTR(send)(s, buf.data(), buf.size(), MSG_DONTWAIT),
+ SyscallSucceeds());
+ // The new buffer size should be smaller that the amount of data in the queue.
+ ASSERT_THAT(setsockopt(s, SOL_SOCKET, SO_SNDBUF, &kBufSz, sizeof(kBufSz)),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(send)(s, buf.data(), buf.size(), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
TEST_P(StreamUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
int sock = sockets->first_fd();
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index c85f6da0b..4a10ae8d2 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -195,81 +195,6 @@ TEST(SpliceTest, PipeOffsets) {
SyscallFailsWithErrno(ESPIPE));
}
-// Event FDs may be used with splice without an offset.
-TEST(SpliceTest, FromEventFD) {
- // Open the input eventfd with an initial value so that it is readable.
- constexpr uint64_t kEventFDValue = 1;
- int efd;
- ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds());
- const FileDescriptor in_fd(efd);
-
- // Create a new pipe.
- int fds[2];
- ASSERT_THAT(pipe(fds), SyscallSucceeds());
- const FileDescriptor rfd(fds[0]);
- const FileDescriptor wfd(fds[1]);
-
- // Splice 8-byte eventfd value to pipe.
- constexpr int kEventFDSize = 8;
- EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0),
- SyscallSucceedsWithValue(kEventFDSize));
-
- // Contents should be equal.
- std::vector<char> rbuf(kEventFDSize);
- ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
- SyscallSucceedsWithValue(kEventFDSize));
- EXPECT_EQ(memcmp(rbuf.data(), &kEventFDValue, rbuf.size()), 0);
-}
-
-// Event FDs may not be used with splice with an offset.
-TEST(SpliceTest, FromEventFDOffset) {
- int efd;
- ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
- const FileDescriptor in_fd(efd);
-
- // Create a new pipe.
- int fds[2];
- ASSERT_THAT(pipe(fds), SyscallSucceeds());
- const FileDescriptor rfd(fds[0]);
- const FileDescriptor wfd(fds[1]);
-
- // Attempt to splice 8-byte eventfd value to pipe with offset.
- //
- // This is not allowed because eventfd doesn't support pread.
- constexpr int kEventFDSize = 8;
- loff_t in_off = 0;
- EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0),
- SyscallFailsWithErrno(EINVAL));
-}
-
-// Event FDs may not be used with splice with an offset.
-TEST(SpliceTest, ToEventFDOffset) {
- // Create a new pipe.
- int fds[2];
- ASSERT_THAT(pipe(fds), SyscallSucceeds());
- const FileDescriptor rfd(fds[0]);
- const FileDescriptor wfd(fds[1]);
-
- // Fill with a value.
- constexpr int kEventFDSize = 8;
- std::vector<char> buf(kEventFDSize);
- buf[0] = 1;
- ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
- SyscallSucceedsWithValue(kEventFDSize));
-
- int efd;
- ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
- const FileDescriptor out_fd(efd);
-
- // Attempt to splice 8-byte eventfd value to pipe with offset.
- //
- // This is not allowed because eventfd doesn't support pwrite.
- loff_t out_off = 0;
- EXPECT_THAT(
- splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0),
- SyscallFailsWithErrno(EINVAL));
-}
-
TEST(SpliceTest, ToPipe) {
// Open the input file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
@@ -852,34 +777,6 @@ TEST(SpliceTest, FromPipeMaxFileSize) {
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
}
-TEST(SpliceTest, FromPipeToDevZero) {
- // Create a new pipe.
- int fds[2];
- ASSERT_THAT(pipe(fds), SyscallSucceeds());
- const FileDescriptor rfd(fds[0]);
- FileDescriptor wfd(fds[1]);
-
- // Fill with some random data.
- std::vector<char> buf(kPageSize);
- RandomizeBuffer(buf.data(), buf.size());
- ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
- SyscallSucceedsWithValue(kPageSize));
-
- const FileDescriptor zero =
- ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY));
-
- // Close the write end to prevent blocking below.
- wfd.reset();
-
- // Splice to /dev/zero. The first call should empty the pipe, and the return
- // value should not exceed the number of bytes available for reading.
- EXPECT_THAT(
- splice(rfd.get(), nullptr, zero.get(), nullptr, kPageSize + 123, 0),
- SyscallSucceedsWithValue(kPageSize));
- EXPECT_THAT(splice(rfd.get(), nullptr, zero.get(), nullptr, 1, 0),
- SyscallSucceedsWithValue(0));
-}
-
static volatile int signaled = 0;
void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; }
diff --git a/test/util/BUILD b/test/util/BUILD
index b92af1c27..2dcf71613 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "cc_library", "cc_test", "coreutil", "default_net_util", "gbenchmark", "gtest", "select_system")
+load("//tools:defs.bzl", "cc_library", "cc_test", "coreutil", "default_net_util", "gbenchmark_internal", "gtest", "select_system")
package(
default_visibility = ["//:sandbox"],
@@ -295,7 +295,7 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
gtest,
- gbenchmark,
+ gbenchmark_internal,
],
)
diff --git a/tools/bazeldefs/cc.bzl b/tools/bazeldefs/cc.bzl
index 2831eac5f..57d33726a 100644
--- a/tools/bazeldefs/cc.bzl
+++ b/tools/bazeldefs/cc.bzl
@@ -9,6 +9,7 @@ cc_test = _cc_test
cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
gtest = "@com_google_googletest//:gtest"
gbenchmark = "@com_google_benchmark//:benchmark"
+gbenchmark_internal = "@com_google_benchmark//:benchmark"
grpcpp = "@com_github_grpc_grpc//:grpc++"
vdso_linker_option = "-fuse-ld=gold "
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 27542a2f5..f4266e1de 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -9,7 +9,7 @@ load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
load("//tools/nogo:defs.bzl", "nogo_test")
load("//tools/bazeldefs:defs.bzl", _arch_genrule = "arch_genrule", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _more_shards = "more_shards", _most_shards = "most_shards", _proto_library = "proto_library", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _version = "version")
-load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option")
+load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _gbenchmark_internal = "gbenchmark_internal", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option")
load("//tools/bazeldefs:go.bzl", _bazel_worker_proto = "bazel_worker_proto", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_rule = "go_rule", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos")
load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
@@ -37,6 +37,7 @@ cc_library = _cc_library
cc_test = _cc_test
cc_toolchain = _cc_toolchain
gbenchmark = _gbenchmark
+gbenchmark_internal = _gbenchmark_internal
gtest = _gtest
grpcpp = _grpcpp
vdso_linker_option = _vdso_linker_option
diff --git a/website/_config.yml b/website/_config.yml
index dc44945bc..5f1cbbdeb 100644
--- a/website/_config.yml
+++ b/website/_config.yml
@@ -44,3 +44,6 @@ authors:
mpratt:
name: Michael Pratt
email: mpratt@google.com
+ nybidari:
+ name: Nayana Bidari
+ email: nybidari@google.com
diff --git a/website/assets/images/2021-08-31-rack-figure1.png b/website/assets/images/2021-08-31-rack-figure1.png
new file mode 100644
index 000000000..6d9fdb147
--- /dev/null
+++ b/website/assets/images/2021-08-31-rack-figure1.png
Binary files differ
diff --git a/website/assets/images/2021-08-31-rack-figure2.png b/website/assets/images/2021-08-31-rack-figure2.png
new file mode 100644
index 000000000..c2043ecae
--- /dev/null
+++ b/website/assets/images/2021-08-31-rack-figure2.png
Binary files differ
diff --git a/website/assets/images/2021-08-31-rack-figure3.png b/website/assets/images/2021-08-31-rack-figure3.png
new file mode 100644
index 000000000..e8b689f33
--- /dev/null
+++ b/website/assets/images/2021-08-31-rack-figure3.png
Binary files differ
diff --git a/website/blog/2021-08-31-gvisor-rack.md b/website/blog/2021-08-31-gvisor-rack.md
new file mode 100644
index 000000000..e7d4582e4
--- /dev/null
+++ b/website/blog/2021-08-31-gvisor-rack.md
@@ -0,0 +1,120 @@
+# gVisor RACK
+
+gVisor has implemented the [RACK](https://datatracker.ietf.org/doc/html/rfc8985)
+(Recent ACKnowledgement) TCP loss-detection algorithm in our network stack,
+which improves throughput in the presence of packet loss and reordering.
+
+TCP is a connection-oriented protocol that detects and recovers from loss by
+retransmitting packets. [RACK](https://datatracker.ietf.org/doc/html/rfc8985) is
+one of the recent loss-detection methods implemented in Linux and BSD, which
+helps in identifying packet loss quickly and accurately in the presence of
+packet reordering and tail losses.
+
+## Background
+
+The TCP congestion window indicates the number of unacknowledged packets that
+can be sent at any time. When packet loss is identified, the congestion window
+is reduced depending on the type of loss. The sender will recover from the loss
+after all the packets sent before reducing the congestion window are
+acknowledged. If the loss is identified falsely by the connection, then the
+connection enters loss recovery unnecessarily, resulting in sending fewer
+packets.
+
+Packet loss is identified mainly in two ways:
+
+1. Three duplicate acknowledgments, which will result in either
+ [Fast](https://datatracker.ietf.org/doc/html/rfc2001#section-4) or
+ [SACK](https://datatracker.ietf.org/doc/html/rfc6675) recovery. The
+ congestion window is reduced depending on the type of congestion control
+ algorithm. For example, in the
+ [Reno](https://en.wikipedia.org/wiki/TCP_congestion_control#TCP_Tahoe_and_Reno)
+ algorithm it is reduced to half.
+2. RTO (Retransmission Timeout) which will result in Timeout recovery. The
+ congestion window is reduced to one
+ [MSS](https://en.wikipedia.org/wiki/Maximum_segment_size).
+
+Both of these cases result in reducing the congestion window, with RTO being
+more expensive. Most of the existing algorithms do not detect packet reordering,
+which get incorrectly identified as packet loss, resulting in an RTO.
+Furthermore, the loss of an ACK at the end of a sequence (known as "tail loss")
+will also trigger RTO and slow down future transmissions unnecessarily. RACK
+helps us to identify loss accurately in all these scenarios, and will avoid
+entering RTO.
+
+## Implementation of RACK
+
+Implementation of RACK requires support for:
+
+1. Per-packet transmission timestamps: RACK detects loss depending on the
+ transmission times of the packet and the timestamp at which ACK was
+ received.
+2. SACK and ability to detect DSACK: Selective Acknowledgement and Duplicate
+ SACK are used to adjust the timer window after which a packet can be marked
+ as lost.
+
+### Packet Reordering
+
+Packet reordering commonly occurs when different packets take different paths
+through a network. The diagram below shows the transmission of four packets
+which get reordered in transmission, and the resulting TCP behavior with and
+without RACK.
+
+![Figure 1](/assets/images/2021-08-31-rack-figure1.png "Packet reordering.")
+
+In the above example, the sender sees three duplicate acknowledgments. Without
+RACK, this is identified falsely as packet loss, and the congestion window will
+be reduced after entering Fast/SACK recovery.
+
+To detect packet reordering, RACK uses a reorder window, bounded between
+[[RTT](https://en.wikipedia.org/wiki/Round-trip_delay)/4, RTT]. The reorder
+timer is set to expire after _RTT+reorder\_window_. A packet is marked as lost
+when the packets following it were acknowledged using SACK and the reorder timer
+expires. The reorder window is increased when a DSACK is received (which
+indicates that there is a higher degree of reordering).
+
+### Tail Loss
+
+Tail loss occurs when the packets are lost at the end of data transmission. The
+diagram below shows an example of tail loss when the last three packets are
+lost, and how it is handled with and without RACK.
+
+![Figure 2](/assets/images/2021-08-31-rack-figure2.png "Tail loss figure 2.")
+
+For tail losses, RACK uses a Tail Loss Probe (TLP), which relies on a timer for
+the last packet sent. The TLP timer is set to _2 \* RTT,_ after which a probe is
+sent. The probe packet will allow the connection one more chance to detect a
+loss by triggering ACK feedback to avoid entering RTO. In the above example, the
+loss is recovered without entering the RTO.
+
+TLP will also help in cases where the ACK was lost but all the packets were
+received by the receiver. The below diagram shows that the ACK received for the
+probe packet avoided the RTO.
+
+![Figure 3](/assets/images/2021-08-31-rack-figure3.png "Tail loss figure 3.")
+
+If there was some loss, then the ACK for the probe packet will have the SACK
+blocks, which will be used to detect and retransmit the lost packets.
+
+In gVisor, we have support for
+[NewReno](https://datatracker.ietf.org/doc/html/rfc6582) and SACK loss recovery
+methods. We
+[added support for RACK](https://github.com/google/gvisor/issues/5243) recently,
+and it is the default when SACK is enabled. After enabling RACK, our internal
+benchmarks in the presence of reordering and tail losses and the data we took
+from internal users inside Google have shown ~50% reduction in the number of
+RTOs.
+
+While RACK has improved one aspect of TCP performance by reducing the timeouts
+in the presence of reordering and tail losses, in gVisor we plan to implement
+the undoing of congestion windows and
+[BBRv2](https://datatracker.ietf.org/doc/html/draft-cardwell-iccrg-bbr-congestion-control)
+(once there is an RFC available) to further improve TCP performance in less
+ideal network conditions.
+
+If you haven’t already, try gVisor. The instructions to get started are in our
+[Quick Start](https://gvisor.dev/docs/user_guide/quick_start/docker/). You can
+also get involved with the gVisor community via our
+[Gitter channel](https://gitter.im/gvisor/community),
+[email list](https://groups.google.com/forum/#!forum/gvisor-users),
+[issue tracker](https://gvisor.dev/issue/new), and
+[Github repository](https://github.com/google/gvisor).
diff --git a/website/blog/BUILD b/website/blog/BUILD
index 17beb721f..0384b9ba9 100644
--- a/website/blog/BUILD
+++ b/website/blog/BUILD
@@ -49,6 +49,16 @@ doc(
permalink = "/blog/2020/10/22/platform-portability/",
)
+doc(
+ name = "gvisor-rack",
+ src = "2021-08-31-gvisor-rack.md",
+ authors = [
+ "nybidari",
+ ],
+ layout = "post",
+ permalink = "/blog/2021/08/31/gvisor-rack/",
+)
+
docs(
name = "posts",
deps = [