diff options
306 files changed, 12496 insertions, 5601 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml index cddf5504b..cb272aef6 100644 --- a/.buildkite/pipeline.yaml +++ b/.buildkite/pipeline.yaml @@ -45,6 +45,14 @@ steps: - make BAZEL_OPTIONS=--config=cross-aarch64 artifacts/aarch64 - make release + # Images tests. + - <<: *common + label: ":docker: Images (x86_64)" + command: make ARCH=x86_64 load-all-images + - <<: *common + label: ":docker: Images (aarch64)" + command: make ARCH=aarch64 load-all-images + # Basic unit tests. - <<: *common label: ":test_tube: Unit tests" @@ -192,7 +200,7 @@ steps: command: make benchmark-platforms BENCHMARKS_SUITE=node BENCHMARKS_TARGETS=test/benchmarks/network:node_test - <<: *benchmarks label: ":redis: Redis benchmarks" - command: make benchmark-platforms BENCHMARKS_SUITE=redis BENCHMARKS_TARGETS=test/benchmarks/database:redis_test + command: make benchmark-platforms BENCHMARKS_SUITE=redis BENCHMARKS_TARGETS=test/benchmarks/database:redis_test BENCHMARKS_OPTIONS=-test.benchtime=15s - <<: *benchmarks label: ":ruby: Ruby benchmarks" command: make benchmark-platforms BENCHMARKS_SUITE=ruby BENCHMARKS_TARGETS=test/benchmarks/network:ruby_test diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md index 69033357c..8e5721ad1 100644 --- a/g3doc/user_guide/FAQ.md +++ b/g3doc/user_guide/FAQ.md @@ -137,9 +137,16 @@ sandbox isolation. There are a few different workarounds you can try: * Use IPs instead of container names. * Use [Kubernetes][k8s]. Container name lookup works fine in Kubernetes. +### I'm getting an error like `dial unix /run/containerd/s/09e4...8cff: connect: connection refused: unknown` {#shim-connect} + +This error may happen when using `gvisor-containerd-shim` with a `containerd` +that does not contain the fix for [CVE-2020-15257]. The resolve the issue, +update containerd to 1.3.9 or 1.4.3 (or newer versions respectively). + [security-model]: /docs/architecture_guide/security/ [host-net]: /docs/user_guide/networking/#network-passthrough [debugging]: /docs/user_guide/debugging/ [filesystem]: /docs/user_guide/filesystem/ [docker]: /docs/user_guide/quick_start/docker/ [k8s]: /docs/user_guide/quick_start/kubernetes/ +[CVE-2020-15257]: https://github.com/containerd/containerd/security/advisories/GHSA-36xw-fx78-c5r4 diff --git a/g3doc/user_guide/containerd/configuration.md b/g3doc/user_guide/containerd/configuration.md index 4f5e721be..011af3b10 100644 --- a/g3doc/user_guide/containerd/configuration.md +++ b/g3doc/user_guide/containerd/configuration.md @@ -1,8 +1,8 @@ # Containerd Advanced Configuration This document describes how to configure runtime options for -`containerd-shim-runsc-v1`. This follows the -[Containerd Quick Start](./quick_start.md) and requires containerd 1.2 or later. +`containerd-shim-runsc-v1`. You can find the installation instructions and +minimal requirements in [Containerd Quick Start](./quick_start.md). ## Shim Configuration @@ -47,27 +47,6 @@ When you are done, restart containerd to pick up the changes. sudo systemctl restart containerd ``` -### Containerd 1.2 - -For containerd 1.2, the config file is not configurable. It should be named -`config.toml` and located in the runtime root. By default, this is -`/run/containerd/runsc`. - -### Example: Enable the KVM platform - -gVisor enables the use of a number of platforms. This example shows how to -configure `containerd-shim-runsc-v1` to use gvisor with the KVM platform. - -Find out more about platform in the -[Platforms Guide](../../architecture_guide/platforms.md). - -```shell -cat <<EOF | sudo tee /etc/containerd/runsc.toml -[runsc_config] - platform = "kvm" -EOF -``` - ## Debug When `shim_debug` is enabled in `/etc/containerd/config.toml`, containerd will diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md index a98fe5c4a..132d80927 100644 --- a/g3doc/user_guide/containerd/quick_start.md +++ b/g3doc/user_guide/containerd/quick_start.md @@ -1,7 +1,7 @@ # Containerd Quick Start This document describes how to use `containerd-shim-runsc-v1` with the -containerd runtime handler support on `containerd` 1.2 or later. +containerd runtime handler support on `containerd`. > ⚠️ NOTE: If you are using Kubernetes and set up your cluster using kubeadm you > may run into issues. See the [FAQ](../FAQ.md#runtime-handler) for details. @@ -11,7 +11,8 @@ containerd runtime handler support on `containerd` 1.2 or later. - **runsc** and **containerd-shim-runsc-v1**: See the [installation guide](/docs/user_guide/install/). - **containerd**: See the [containerd website](https://containerd.io/) for - information on how to install containerd. + information on how to install containerd. **Minimal version supported: 1.3.9 + or 1.4.3.** ## Configure containerd diff --git a/images/basic/fsstress/Dockerfile b/images/basic/fsstress/Dockerfile.x86_64 index 84604ead5..21b86065a 100644 --- a/images/basic/fsstress/Dockerfile +++ b/images/basic/fsstress/Dockerfile.x86_64 @@ -1,7 +1,7 @@ # Usage: docker run --rm fsstress -d /test -n 10000 -p 100 -X -v FROM alpine -RUN apk add git +RUN apk update && apk add git RUN git clone https://github.com/linux-test-project/ltp.git --depth 1 WORKDIR /ltp diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index d1ca56370..b84d7c048 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -21,6 +21,7 @@ const ( F_SETFD = 2 F_GETFL = 3 F_SETFL = 4 + F_GETLK = 5 F_SETLK = 6 F_SETLKW = 7 F_SETOWN = 8 @@ -55,7 +56,7 @@ type Flock struct { _ [4]byte Start int64 Len int64 - Pid int32 + PID int32 _ [4]byte } diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 8591acbf2..cb33c37bd 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -416,6 +416,18 @@ type TCPInfo struct { RwndLimited uint64 // SndBufLimited is the time in microseconds limited by send buffer. SndBufLimited uint64 + + Delievered uint32 + DelieveredCe uint32 + + // BytesSent is RFC4898 tcpEStatsPerfHCDataOctetsOut. + BytesSent uint64 + // BytesRetrans is RFC4898 tcpEStatsPerfOctetsRetrans. + BytesRetrans uint64 + // DSACKDups is RFC4898 tcpEStatsStackDSACKDups. + DSACKDups uint32 + // ReordSeen is the number of reordering events seen. + ReordSeen uint32 } // SizeOfTCPInfo is the binary size of a TCPInfo struct. diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go index 2a8d4708b..1a3c0916f 100644 --- a/pkg/abi/linux/tcp.go +++ b/pkg/abi/linux/tcp.go @@ -59,3 +59,12 @@ const ( MAX_TCP_KEEPINTVL = 32767 MAX_TCP_KEEPCNT = 127 ) + +// Congestion control states from include/uapi/linux/tcp.h. +const ( + TCP_CA_Open = 0 + TCP_CA_Disorder = 1 + TCP_CA_CWR = 2 + TCP_CA_Recovery = 3 + TCP_CA_Loss = 4 +) diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go index fdfe31417..6f3d72e83 100644 --- a/pkg/coverage/coverage.go +++ b/pkg/coverage/coverage.go @@ -26,7 +26,6 @@ import ( "fmt" "io" "sort" - "sync/atomic" "testing" "gvisor.dev/gvisor/pkg/sync" @@ -69,12 +68,18 @@ var globalData struct { } // ClearCoverageData clears existing coverage data. +// +//go:norace func ClearCoverageData() { coverageMu.Lock() defer coverageMu.Unlock() + + // We do not use atomic operations while reading/writing to the counters, + // which would drastically degrade performance. Slight discrepancies due to + // racing is okay for the purposes of kcov. for _, counters := range coverdata.Cover.Counters { for index := 0; index < len(counters); index++ { - atomic.StoreUint32(&counters[index], 0) + counters[index] = 0 } } } @@ -114,6 +119,8 @@ var coveragePool = sync.Pool{ // ensure that each event is only reported once. Due to the limitations of Go // coverage tools, we reset the global coverage data every time this function is // run. +// +//go:norace func ConsumeCoverageData(w io.Writer) int { InitCoverageData() @@ -125,11 +132,14 @@ func ConsumeCoverageData(w io.Writer) int { for fileNum, file := range globalData.files { counters := coverdata.Cover.Counters[file] for index := 0; index < len(counters); index++ { - if atomic.LoadUint32(&counters[index]) == 0 { + // We do not use atomic operations while reading/writing to the counters, + // which would drastically degrade performance. Slight discrepancies due to + // racing is okay for the purposes of kcov. + if counters[index] == 0 { continue } // Non-zero coverage data found; consume it and report as a PC. - atomic.StoreUint32(&counters[index], 0) + counters[index] = 0 pc := globalData.syntheticPCs[fileNum][index] usermem.ByteOrder.PutUint64(pcBuffer[:], pc) n, err := w.Write(pcBuffer[:]) diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go index af8230a7d..387f713aa 100644 --- a/pkg/sentry/fs/fdpipe/pipe_state.go +++ b/pkg/sentry/fs/fdpipe/pipe_state.go @@ -34,7 +34,9 @@ func (p *pipeOperations) beforeSave() { } else if p.flags.Write { file, err := p.opener.NonBlockingOpen(context.Background(), fs.PermMask{Write: true}) if err != nil { - panic(fs.ErrSaveRejection{fmt.Errorf("write-only pipe end cannot be re-opened as %v: %v", p, err)}) + panic(&fs.ErrSaveRejection{ + Err: fmt.Errorf("write-only pipe end cannot be re-opened as %#v: %w", p, err), + }) } file.Close() } diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index a020da53b..44587bb37 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -144,7 +144,7 @@ type ErrSaveRejection struct { } // Error returns a sensible description of the save rejection error. -func (e ErrSaveRejection) Error() string { +func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported file system state: " + e.Err.Error() } diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go index a3402e343..141e3c27f 100644 --- a/pkg/sentry/fs/gofer/inode_state.go +++ b/pkg/sentry/fs/gofer/inode_state.go @@ -67,7 +67,9 @@ func (i *inodeFileState) beforeSave() { if i.sattr.Type == fs.RegularFile { uattr, err := i.unstableAttr(&dummyClockContext{context.Background()}) if err != nil { - panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.s.inodeMappings[i.sattr.InodeID], err)}) + panic(&fs.ErrSaveRejection{ + Err: fmt.Errorf("failed to get unstable atttribute of %s: %w", i.s.inodeMappings[i.sattr.InodeID], err), + }) } i.savedUAttr = &uattr } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index a2f3d5918..07b4fb70f 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -257,7 +257,7 @@ func (c *ConnectedEndpoint) Passcred() bool { } // GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress. -func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil } diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index ae3331737..4d3b216d8 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -41,6 +41,8 @@ go_library( ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/sync", "//pkg/waiter", diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go index 8a5d9c7eb..57686ce07 100644 --- a/pkg/sentry/fs/lock/lock.go +++ b/pkg/sentry/fs/lock/lock.go @@ -54,6 +54,8 @@ import ( "math" "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) @@ -83,6 +85,17 @@ const ( // offset 0 to LockEOF. const LockEOF = math.MaxUint64 +// OwnerInfo describes the owner of a lock. +// +// TODO(gvisor.dev/issue/5264): We may need to add other fields in the future +// (e.g., Linux's file_lock.fl_flags to support open file-descriptor locks). +// +// +stateify savable +type OwnerInfo struct { + // PID is the process ID of the lock owner. + PID int32 +} + // Lock is a regional file lock. It consists of either a single writer // or a set of readers. // @@ -92,14 +105,20 @@ const LockEOF = math.MaxUint64 // A Lock may be downgraded from a write lock to a read lock only if // the write lock's uid is the same as the read lock. // +// Accesses to Lock are synchronized through the Locks object to which it +// belongs. +// // +stateify savable type Lock struct { // Readers are the set of read lock holders identified by UniqueID. - // If len(Readers) > 0 then HasWriter must be false. - Readers map[UniqueID]bool + // If len(Readers) > 0 then Writer must be nil. + Readers map[UniqueID]OwnerInfo // Writer holds the writer unique ID. It's nil if there are no writers. Writer UniqueID + + // WriterInfo describes the writer. It is only meaningful if Writer != nil. + WriterInfo OwnerInfo } // Locks is a thread-safe wrapper around a LockSet. @@ -135,14 +154,14 @@ const ( // acquiring the lock in a non-blocking mode or "interrupted" if in a blocking mode. // Blocker is the interface used to provide blocking behavior, passing a nil Blocker // will result in non-blocking behavior. -func (l *Locks) LockRegion(uid UniqueID, t LockType, r LockRange, block Blocker) bool { +func (l *Locks) LockRegion(uid UniqueID, ownerPID int32, t LockType, r LockRange, block Blocker) bool { for { l.mu.Lock() // Blocking locks must run in a loop because we'll be woken up whenever an unlock event // happens for this lock. We will then attempt to take the lock again and if it fails // continue blocking. - res := l.locks.lock(uid, t, r) + res := l.locks.lock(uid, ownerPID, t, r) if !res && block != nil { e, ch := waiter.NewChannelEntry(nil) l.blockedQueue.EventRegister(&e, EventMaskAll) @@ -161,6 +180,14 @@ func (l *Locks) LockRegion(uid UniqueID, t LockType, r LockRange, block Blocker) } } +// LockRegionVFS1 is a wrapper around LockRegion for VFS1, which does not implement +// F_GETLK (and does not care about storing PIDs as a result). +// +// TODO(gvisor.dev/issue/1624): Delete. +func (l *Locks) LockRegionVFS1(uid UniqueID, t LockType, r LockRange, block Blocker) bool { + return l.LockRegion(uid, 0 /* ownerPID */, t, r, block) +} + // UnlockRegion attempts to release a lock for the uid on a region of a file. // This operation is always successful, even if there did not exist a lock on // the requested region held by uid in the first place. @@ -175,13 +202,14 @@ func (l *Locks) UnlockRegion(uid UniqueID, r LockRange) { // makeLock returns a new typed Lock that has either uid as its only reader // or uid as its only writer. -func makeLock(uid UniqueID, t LockType) Lock { - value := Lock{Readers: make(map[UniqueID]bool)} +func makeLock(uid UniqueID, ownerPID int32, t LockType) Lock { + value := Lock{Readers: make(map[UniqueID]OwnerInfo)} switch t { case ReadLock: - value.Readers[uid] = true + value.Readers[uid] = OwnerInfo{PID: ownerPID} case WriteLock: value.Writer = uid + value.WriterInfo = OwnerInfo{PID: ownerPID} default: panic(fmt.Sprintf("makeLock: invalid lock type %d", t)) } @@ -190,17 +218,20 @@ func makeLock(uid UniqueID, t LockType) Lock { // isHeld returns true if uid is a holder of Lock. func (l Lock) isHeld(uid UniqueID) bool { - return l.Writer == uid || l.Readers[uid] + if _, ok := l.Readers[uid]; ok { + return true + } + return l.Writer == uid } // lock sets uid as a holder of a typed lock on Lock. // // Preconditions: canLock is true for the range containing this Lock. -func (l *Lock) lock(uid UniqueID, t LockType) { +func (l *Lock) lock(uid UniqueID, ownerPID int32, t LockType) { switch t { case ReadLock: // If we are already a reader, then this is a no-op. - if l.Readers[uid] { + if _, ok := l.Readers[uid]; ok { return } // We cannot downgrade a write lock to a read lock unless the @@ -210,11 +241,11 @@ func (l *Lock) lock(uid UniqueID, t LockType) { panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer)) } // Ensure that there is only one reader if upgrading. - l.Readers = make(map[UniqueID]bool) + l.Readers = make(map[UniqueID]OwnerInfo) // Ensure that there is no longer a writer. l.Writer = nil } - l.Readers[uid] = true + l.Readers[uid] = OwnerInfo{PID: ownerPID} return case WriteLock: // If we are already the writer, then this is a no-op. @@ -228,13 +259,14 @@ func (l *Lock) lock(uid UniqueID, t LockType) { if readers != 1 { panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, too many readers %v", uid, l.Readers)) } - if !l.Readers[uid] { + if _, ok := l.Readers[uid]; !ok { panic(fmt.Sprintf("lock: cannot upgrade read lock to write lock for uid %d, conflicting reader %v", uid, l.Readers)) } } // Ensure that there is only a writer. - l.Readers = make(map[UniqueID]bool) + l.Readers = make(map[UniqueID]OwnerInfo) l.Writer = uid + l.WriterInfo = OwnerInfo{PID: ownerPID} default: panic(fmt.Sprintf("lock: invalid lock type %d", t)) } @@ -247,7 +279,7 @@ func (l LockSet) lockable(r LockRange, check func(value Lock) bool) bool { // Get our starting point. seg := l.LowerBoundSegment(r.Start) for seg.Ok() && seg.Start() < r.End { - // Note that we don't care about overruning the end of the + // Note that we don't care about overrunning the end of the // last segment because if everything checks out we'll just // split the last segment. if !check(seg.Value()) { @@ -281,7 +313,7 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool { if value.Writer == nil { // Then this uid can only take a write lock if this is a private // upgrade, meaning that the only reader is uid. - return len(value.Readers) == 1 && value.Readers[uid] + return value.isOnlyReader(uid) } // If the uid is already a writer on this region, then // adding a write lock would be a no-op. @@ -292,11 +324,19 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool { } } +func (l *Lock) isOnlyReader(uid UniqueID) bool { + if len(l.Readers) != 1 { + return false + } + _, ok := l.Readers[uid] + return ok +} + // lock returns true if uid took a lock of type t on the entire range of // LockRange. // // Preconditions: r.Start <= r.End (will panic otherwise). -func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool { +func (l *LockSet) lock(uid UniqueID, ownerPID int32, t LockType, r LockRange) bool { if r.Start > r.End { panic(fmt.Sprintf("lock: r.Start %d > r.End %d", r.Start, r.End)) } @@ -317,7 +357,7 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool { seg, gap := l.Find(r.Start) if gap.Ok() { // Fill in the gap and get the next segment to modify. - seg = l.Insert(gap, gap.Range().Intersect(r), makeLock(uid, t)).NextSegment() + seg = l.Insert(gap, gap.Range().Intersect(r), makeLock(uid, ownerPID, t)).NextSegment() } else if seg.Start() < r.Start { // Get our first segment to modify. _, seg = l.Split(seg, r.Start) @@ -331,12 +371,12 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool { // Set the lock on the segment. This is guaranteed to // always be safe, given canLock above. value := seg.ValuePtr() - value.lock(uid, t) + value.lock(uid, ownerPID, t) // Fill subsequent gaps. gap = seg.NextGap() if gr := gap.Range().Intersect(r); gr.Length() > 0 { - seg = l.Insert(gap, gr, makeLock(uid, t)).NextSegment() + seg = l.Insert(gap, gr, makeLock(uid, ownerPID, t)).NextSegment() } else { seg = gap.NextSegment() } @@ -380,7 +420,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) { // only ever be one writer and no readers, then this // lock should always be removed from the set. remove = true - } else if value.Readers[uid] { + } else if _, ok := value.Readers[uid]; ok { // If uid is the last reader, then just remove the entire // segment. if len(value.Readers) == 1 { @@ -390,7 +430,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) { // affecting any other segment's readers. To do // this, we need to make a copy of the Readers map // and not add this uid. - newValue := Lock{Readers: make(map[UniqueID]bool)} + newValue := Lock{Readers: make(map[UniqueID]OwnerInfo)} for k, v := range value.Readers { if k != uid { newValue.Readers[k] = v @@ -451,3 +491,72 @@ func ComputeRange(start, length, offset int64) (LockRange, error) { // Offset is guaranteed to be positive at this point. return LockRange{Start: uint64(offset), End: end}, nil } + +// TestRegion checks whether the lock holder identified by uid can hold a lock +// of type t on range r. It returns a Flock struct representing this +// information as the F_GETLK fcntl does. +// +// Note that the PID returned in the flock structure is relative to the root PID +// namespace. It needs to be converted to the caller's PID namespace before +// returning to userspace. +// +// TODO(gvisor.dev/issue/5264): we don't support OFD locks through fcntl, which +// would return a struct with pid = -1. +func (l *Locks) TestRegion(ctx context.Context, uid UniqueID, t LockType, r LockRange) linux.Flock { + f := linux.Flock{Type: linux.F_UNLCK} + switch t { + case ReadLock: + l.testRegion(r, func(lock Lock, start, length uint64) bool { + if lock.Writer == nil || lock.Writer == uid { + return true + } + f.Type = linux.F_WRLCK + f.PID = lock.WriterInfo.PID + f.Start = int64(start) + f.Len = int64(length) + return false + }) + case WriteLock: + l.testRegion(r, func(lock Lock, start, length uint64) bool { + if lock.Writer == nil { + for k, v := range lock.Readers { + if k != uid { + // Stop at the first conflict detected. + f.Type = linux.F_RDLCK + f.PID = v.PID + f.Start = int64(start) + f.Len = int64(length) + return false + } + } + return true + } + if lock.Writer == uid { + return true + } + f.Type = linux.F_WRLCK + f.PID = lock.WriterInfo.PID + f.Start = int64(start) + f.Len = int64(length) + return false + }) + default: + panic(fmt.Sprintf("TestRegion: invalid lock type %d", t)) + } + return f +} + +func (l *Locks) testRegion(r LockRange, check func(lock Lock, start, length uint64) bool) { + l.mu.Lock() + defer l.mu.Unlock() + + seg := l.locks.LowerBoundSegment(r.Start) + for seg.Ok() && seg.Start() < r.End { + lock := seg.Value() + if !check(lock, seg.Start(), seg.End()-seg.Start()) { + // Stop at the first conflict detected. + return + } + seg = seg.NextSegment() + } +} diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go index 50a16e662..dcc17c0dc 100644 --- a/pkg/sentry/fs/lock/lock_set_functions.go +++ b/pkg/sentry/fs/lock/lock_set_functions.go @@ -40,7 +40,7 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock) return Lock{}, false } for k := range val1.Readers { - if !val2.Readers[k] { + if _, ok := val2.Readers[k]; !ok { return Lock{}, false } } @@ -53,11 +53,12 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock) func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock) { // Copy the segment so that split segments don't contain map references // to other segments. - val0 := Lock{Readers: make(map[UniqueID]bool)} + val0 := Lock{Readers: make(map[UniqueID]OwnerInfo)} for k, v := range val.Readers { val0.Readers[k] = v } val0.Writer = val.Writer + val0.WriterInfo = val.WriterInfo return val, val0 } diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go index fad90984b..9878c04e1 100644 --- a/pkg/sentry/fs/lock/lock_test.go +++ b/pkg/sentry/fs/lock/lock_test.go @@ -30,12 +30,12 @@ func equals(e0, e1 []entry) bool { } for i := range e0 { for k := range e0[i].Lock.Readers { - if !e1[i].Lock.Readers[k] { + if _, ok := e1[i].Lock.Readers[k]; !ok { return false } } for k := range e1[i].Lock.Readers { - if !e0[i].Lock.Readers[k] { + if _, ok := e0[i].Lock.Readers[k]; !ok { return false } } @@ -90,15 +90,15 @@ func TestCanLock(t *testing.T) { // 0 1024 2048 3072 4096 l := fill([]entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 3: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 3: OwnerInfo{}}}, LockRange: LockRange{2048, 3072}, }, { @@ -220,7 +220,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -266,7 +266,7 @@ func TestSetLock(t *testing.T) { // 0 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, { @@ -283,7 +283,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -302,7 +302,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{0, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -333,7 +333,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -351,7 +351,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -366,11 +366,11 @@ func TestSetLock(t *testing.T) { // 0 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -383,7 +383,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -398,15 +398,15 @@ func TestSetLock(t *testing.T) { // 0 4096 8192 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{4096, 8192}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{8192, LockEOF}, }, }, @@ -419,7 +419,7 @@ func TestSetLock(t *testing.T) { // 0 1024 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, LockEOF}, }, }, @@ -434,7 +434,7 @@ func TestSetLock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -447,7 +447,7 @@ func TestSetLock(t *testing.T) { // 0 4096 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, }, @@ -467,11 +467,11 @@ func TestSetLock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, LockEOF}, }, }, @@ -484,7 +484,7 @@ func TestSetLock(t *testing.T) { // 0 1024 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, LockEOF}, }, }, @@ -499,15 +499,15 @@ func TestSetLock(t *testing.T) { // 0 1024 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{1024, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -520,15 +520,15 @@ func TestSetLock(t *testing.T) { // 0 1024 2048 4096 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -543,7 +543,7 @@ func TestSetLock(t *testing.T) { // 0 1024 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { @@ -551,7 +551,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{1024, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -564,15 +564,15 @@ func TestSetLock(t *testing.T) { // 0 1024 2048 4096 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -587,7 +587,7 @@ func TestSetLock(t *testing.T) { // 0 1024 3072 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { @@ -595,7 +595,7 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{1024, 3072}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -608,11 +608,11 @@ func TestSetLock(t *testing.T) { // 0 1024 2048 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, }, @@ -634,15 +634,15 @@ func TestSetLock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{1024, 2048}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{2048, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -676,7 +676,7 @@ func TestSetLock(t *testing.T) { l := fill(test.before) r := LockRange{Start: test.start, End: test.end} - success := l.lock(test.uid, test.lockType, r) + success := l.lock(test.uid, 0 /* ownerPID */, test.lockType, r) var got []entry for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { got = append(got, entry{ @@ -739,7 +739,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -752,7 +752,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -765,7 +765,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -797,7 +797,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -810,7 +810,7 @@ func TestUnlock(t *testing.T) { // 0 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -849,7 +849,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -862,7 +862,7 @@ func TestUnlock(t *testing.T) { // 0 4096 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}}}, LockRange: LockRange{0, 4096}, }, }, @@ -901,7 +901,7 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { @@ -909,7 +909,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{1024, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -922,11 +922,11 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -939,7 +939,7 @@ func TestUnlock(t *testing.T) { // 0 max uint64 before: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, LockEOF}, }, }, @@ -952,15 +952,15 @@ func TestUnlock(t *testing.T) { // 0 1024 4096 max uint64 after: []entry{ { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{1024, 4096}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true, 2: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}, 2: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -977,7 +977,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -994,7 +994,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 8}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -1011,7 +1011,7 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{4096, LockEOF}, }, }, @@ -1028,11 +1028,11 @@ func TestUnlock(t *testing.T) { LockRange: LockRange{0, 1024}, }, { - Lock: Lock{Readers: map[UniqueID]bool{1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{1: OwnerInfo{}}}, LockRange: LockRange{4096, 8192}, }, { - Lock: Lock{Readers: map[UniqueID]bool{0: true, 1: true}}, + Lock: Lock{Readers: map[UniqueID]OwnerInfo{0: OwnerInfo{}, 1: OwnerInfo{}}}, LockRange: LockRange{8192, LockEOF}, }, }, diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index e91fa26a4..b44117f40 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -194,16 +193,6 @@ func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions return mfd.inode.Stat(ctx, fs, opts) } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (mfd *masterFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return mfd.Locks().LockPOSIX(ctx, &mfd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (mfd *masterFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return mfd.Locks().UnlockPOSIX(ctx, &mfd.vfsfd, uid, start, length, whence) -} - // maybeEmitUnimplementedEvent emits unimplemented event if cmd is valid. func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { switch cmd { diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go index 70c68cf0a..a0c5b5af5 100644 --- a/pkg/sentry/fsimpl/devpts/replica.go +++ b/pkg/sentry/fsimpl/devpts/replica.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -189,13 +188,3 @@ func (rfd *replicaFileDescription) Stat(ctx context.Context, opts vfs.StatOption fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem() return rfd.inode.Stat(ctx, fs, opts) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (rfd *replicaFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return rfd.Locks().LockPOSIX(ctx, &rfd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (rfd *replicaFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return rfd.Locks().UnlockPOSIX(ctx, &rfd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 0ad79b381..512b70ede 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -311,13 +310,3 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in fd.off = offset return offset, nil } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *directoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *directoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index 4a5539b37..5ad9befcd 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -154,13 +153,3 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt // TODO(b/134676337): Implement mmap(2). return syserror.ENODEV } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 3b5927702..9da01cba3 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -90,6 +90,7 @@ type createSyntheticOpts struct { // * d.isDir(). // * d does not already contain a child with the given name. func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { + now := d.fs.clock.Now().Nanoseconds() child := &dentry{ refs: 1, // held by d fs: d.fs, @@ -98,6 +99,10 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { uid: uint32(opts.kuid), gid: uint32(opts.kgid), blockSize: usermem.PageSize, // arbitrary + atime: now, + mtime: now, + ctime: now, + btime: now, readFD: -1, writeFD: -1, mmapFD: -1, diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 91d5dc174..8f95473b6 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -36,16 +36,26 @@ import ( // Sync implements vfs.FilesystemImpl.Sync. func (fs *filesystem) Sync(ctx context.Context) error { // Snapshot current syncable dentries and special file FDs. + fs.renameMu.RLock() fs.syncMu.Lock() ds := make([]*dentry, 0, len(fs.syncableDentries)) for d := range fs.syncableDentries { + // It's safe to use IncRef here even though fs.syncableDentries doesn't + // hold references since we hold fs.renameMu. Note that we can't use + // TryIncRef since cached dentries at zero references should still be + // synced. d.IncRef() ds = append(ds, d) } + fs.renameMu.RUnlock() sffds := make([]*specialFileFD, 0, len(fs.specialFileFDs)) for sffd := range fs.specialFileFDs { - sffd.vfsfd.IncRef() - sffds = append(sffds, sffd) + // As above, fs.specialFileFDs doesn't hold references. However, unlike + // dentries, an FD that has reached zero references can't be + // resurrected, so we can use TryIncRef. + if sffd.vfsfd.TryIncRef() { + sffds = append(sffds, sffd) + } } fs.syncMu.Unlock() diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 3cdb1e659..98f7bc52f 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1944,22 +1944,22 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { } // LockBSD implements vfs.FileDescriptionImpl.LockBSD. -func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { +func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { fd.lockLogging.Do(func() { log.Infof("File lock using gofer file handled internally.") }) - return fd.LockFD.LockBSD(ctx, uid, t, block) + return fd.LockFD.LockBSD(ctx, uid, ownerPID, t, block) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { fd.lockLogging.Do(func() { log.Infof("Range lock using gofer file handled internally.") }) - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) + return fd.Locks().LockPOSIX(ctx, uid, ownerPID, t, r, block) } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + return fd.Locks().UnlockPOSIX(ctx, uid, r) } diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 36a3f6810..05f11fbd5 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -810,13 +809,3 @@ func (f *fileDescription) EventUnregister(e *waiter.Entry) { func (f *fileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { return fdnotifier.NonBlockingPoll(int32(f.inode.hostFD), mask) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (f *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return f.Locks().LockPOSIX(ctx, &f.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (f *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return f.Locks().UnlockPOSIX(ctx, &f.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 60acc367f..72aa535f8 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -201,7 +201,7 @@ func (c *ConnectedEndpoint) Passcred() bool { } // GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress. -func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil } diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index f5c596fec..0f9e20a84 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -370,13 +369,3 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) return syserror.ERESTARTSYS } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (t *TTYFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, typ fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return t.Locks().LockPOSIX(ctx, &t.vfsfd, uid, typ, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (t *TTYFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return t.Locks().UnlockPOSIX(ctx, &t.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 485504995..65054b0ea 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -136,13 +135,3 @@ func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { // DynamicBytesFiles are immutable. return syserror.EPERM } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *DynamicBytesFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *DynamicBytesFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index f8dae22f8..e55111af0 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -275,13 +274,3 @@ func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptio func (fd *GenericDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { return fd.DirectoryFileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *GenericDirectoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index eac578f25..8139bff76 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -371,6 +371,8 @@ type OrderedChildrenOptions struct { // OrderedChildren may modify the tracked children. This applies to // operations related to rename, unlink and rmdir. If an OrderedChildren is // not writable, these operations all fail with EPERM. + // + // Note that writable users must implement the sticky bit (I_SVTX). Writable bool } @@ -556,7 +558,6 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child Inode) return err } - // TODO(gvisor.dev/issue/3027): Check sticky bit before removing. o.removeLocked(name) return nil } @@ -603,8 +604,8 @@ func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, c if err := o.checkExistingLocked(oldname, child); err != nil { return err } + o.removeLocked(oldname) - // TODO(gvisor.dev/issue/3027): Check sticky bit before removing. dst.replaceChildLocked(ctx, newname, child) return nil } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 3492409b2..082fa6504 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -42,7 +42,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refsvfs2" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -820,13 +819,3 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 75be6129f..fdae163d1 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -518,16 +517,6 @@ func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error { // Release implements vfs.FileDescriptionImpl.Release. func (fd *memFD) Release(context.Context) {} -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} - // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. // // +stateify savable @@ -1110,13 +1099,3 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err func (fd *namespaceFD) Release(ctx context.Context) { fd.inode.DecRef(ctx) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *namespaceFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *namespaceFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 7ee6227a9..d6f076cd6 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -393,7 +393,7 @@ func TestProcSelf(t *testing.T) { t.Fatalf("CreateTask(): %v", err) } - collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{ + collector := s.WithTemporaryContext(task.AsyncContext()).ListDirents(&vfs.PathOperation{ Root: s.Root, Start: s.Root, Path: fspath.Parse("/proc/self/"), @@ -491,11 +491,11 @@ func TestTree(t *testing.T) { t.Fatalf("CreateTask(): %v", err) } // Add file to populate /proc/[pid]/fd and fdinfo directories. - task.FDTable().NewFDVFS2(task, 0, file, kernel.FDFlags{}) + task.FDTable().NewFDVFS2(task.AsyncContext(), 0, file, kernel.FDFlags{}) tasks = append(tasks, task) } - ctx := tasks[0] + ctx := tasks[0].AsyncContext() fd, err := s.VFS.OpenAt( ctx, auth.CredentialsFromContext(s.Ctx), diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go index 146c7fdfe..4393cc13b 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go @@ -140,35 +140,35 @@ func TestLocks(t *testing.T) { uid1 := 123 uid2 := 456 - if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, nil); err != nil { + if err := fd.Impl().LockBSD(ctx, uid1, 0 /* ownerPID */, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, nil); err != nil { + if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want) } if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil { t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err) } - if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil); err != nil { + if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, 0, 1, linux.SEEK_SET, nil); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 1, 2, linux.SEEK_SET, nil); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 1, End: 2}, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, 0, 1, linux.SEEK_SET, nil); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.WriteLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 0, 1, linux.SEEK_SET, nil), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil), syserror.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want) } - if err := fd.Impl().UnlockPOSIX(ctx, uid1, 0, 1, linux.SEEK_SET); err != nil { + if err := fd.Impl().UnlockPOSIX(ctx, uid1, lock.LockRange{Start: 0, End: 1}); err != nil { t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err) } } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 0c9c639d3..b32c54e20 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -36,7 +36,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -797,16 +796,6 @@ func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { return nil } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} - // Sync implements vfs.FileDescriptionImpl.Sync. It does nothing because all // filesystem state is in-memory. func (*fileDescription) Sync(context.Context) error { diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index a5171b5ad..8645078a0 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -660,7 +660,6 @@ func (d *dentry) readlink(ctx context.Context) (string, error) { type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - vfs.LockFD // d is the corresponding dentry to the fileDescription. d *dentry @@ -1104,14 +1103,29 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op return 0, syserror.EROFS } +// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { + return fd.lowerFD.LockBSD(ctx, ownerPID, t, block) +} + +// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. +func (fd *fileDescription) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { + return fd.lowerFD.UnlockBSD(ctx) +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { + return fd.lowerFD.LockPOSIX(ctx, uid, ownerPID, t, r, block) } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence) +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + return fd.lowerFD.UnlockPOSIX(ctx, uid, r) +} + +// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return fd.lowerFD.TestPOSIX(ctx, uid, t, r) } // FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 30d8b4355..798d6a9bd 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -66,7 +66,7 @@ func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry { // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. -func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { +func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, context.Context, error) { t.Helper() k, err := testutil.Boot() if err != nil { @@ -119,7 +119,7 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, root.DecRef(ctx) mntns.DecRef(ctx) }) - return vfsObj, root, task, nil + return vfsObj, root, task.AsyncContext(), nil } // openVerityAt opens a verity file. diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 0ee60569c..8a5b11d40 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -240,7 +240,6 @@ go_library( "//pkg/sentry/fs/lock", "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", - "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/pipefs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/fsimpl/timerfd", diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 7aba31587..a6afabb1c 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -153,20 +153,12 @@ func (f *FDTable) drop(ctx context.Context, file *fs.File) { // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { - // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the - // entire file. - err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET) + // Release any POSIX lock possibly held by the FDTable. + err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF}) if err != nil && err != syserror.ENOLCK { panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) } - // Generate inotify events. - ev := uint32(linux.IN_CLOSE_NOWRITE) - if file.IsWritable() { - ev = linux.IN_CLOSE_WRITE - } - file.Dentry().InotifyWithParent(ctx, ev, 0, vfs.PathEvent) - // Drop the table's reference. file.DecRef(ctx) } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 303ae8056..ef4e934a1 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -593,8 +593,8 @@ func (k *Kernel) flushWritesToFiles(ctx context.Context) error { // Wrap this error in ErrSaveRejection so that it will trigger a save // error, rather than a panic. This also allows us to distinguish Fsync // errors from state file errors in state.Save. - return fs.ErrSaveRejection{ - Err: fmt.Errorf("%q was not sufficiently synced: %v", name, err), + return &fs.ErrSaveRejection{ + Err: fmt.Errorf("%q was not sufficiently synced: %w", name, err), } } return nil diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 2c32d017d..71daa9f4b 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -27,7 +27,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index d5a91730d..3b6336e94 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -441,13 +440,3 @@ func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFr } return n, err } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *VFSPipeFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *VFSPipeFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index 9419f2e95..ecbe8f920 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -69,7 +69,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time. // syserror.ErrInterrupted if t is interrupted. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) BlockWithDeadline(C chan struct{}, haveDeadline bool, deadline ktime.Time) error { +func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline ktime.Time) error { if !haveDeadline { return t.block(C, nil) } diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index fdadb52c0..e9da99067 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -216,7 +216,7 @@ func (ns *PIDNamespace) TaskWithID(tid ThreadID) *Task { return t } -// ThreadGroupWithID returns the thread group lead by the task with thread ID +// ThreadGroupWithID returns the thread group led by the task with thread ID // tid in PID namespace ns. If no task has that TID, or if the task with that // TID is not a thread group leader, ThreadGroupWithID returns nil. func (ns *PIDNamespace) ThreadGroupWithID(tid ThreadID) *ThreadGroup { @@ -292,6 +292,11 @@ func (ns *PIDNamespace) UserNamespace() *auth.UserNamespace { return ns.userns } +// Root returns the root PID namespace of ns. +func (ns *PIDNamespace) Root() *PIDNamespace { + return ns.owner.Root +} + // A threadGroupNode defines the relationship between a thread group and the // rest of the system. Conceptually, threadGroupNode is data belonging to the // owning TaskSet, as if TaskSet contained a field `nodes @@ -485,3 +490,8 @@ func (t *Task) Parent() *Task { func (t *Task) ThreadID() ThreadID { return t.tg.pidns.IDOfTask(t) } + +// TGIDInRoot returns t's TGID in the root PID namespace. +func (t *Task) TGIDInRoot() ThreadID { + return t.tg.pidns.owner.Root.IDOfThreadGroup(t.tg) +} diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go index 6efe5102b..73bfbea49 100644 --- a/pkg/sentry/mm/procfs.go +++ b/pkg/sentry/mm/procfs.go @@ -17,7 +17,6 @@ package mm import ( "bytes" "fmt" - "strings" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile" @@ -165,12 +164,12 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI } if s != "" { // Per linux, we pad until the 74th character. - if pad := 73 - lineLen; pad > 0 { - b.WriteString(strings.Repeat(" ", pad)) + for pad := 73 - lineLen; pad > 0; pad-- { + b.WriteByte(' ') } b.WriteString(s) } - b.WriteString("\n") + b.WriteByte('\n') } // ReadSmapsDataInto is called by fsimpl/proc.smapsData.Generate to diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 675efdc7c..69e37330b 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -1055,18 +1055,11 @@ func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error { mm.activeMu.Lock() defer mm.activeMu.Unlock() - // Linux's mm/madvise.c:madvise_dontneed() => mm/memory.c:zap_page_range() - // is analogous to our mm.invalidateLocked(ar, true, true). We inline this - // here, with the special case that we synchronously decommit - // uniquely-owned (non-copy-on-write) pages for private anonymous vma, - // which is the common case for MADV_DONTNEED. Invalidating these pmas, and - // allowing them to be reallocated when touched again, increases pma - // fragmentation, which may significantly reduce performance for - // non-vectored I/O implementations. Also, decommitting synchronously - // ensures that Decommit immediately reduces host memory usage. + // This is invalidateLocked(invalidatePrivate=true, invalidateShared=true), + // with the additional wrinkle that we must refuse to invalidate pmas under + // mlocked vmas. var didUnmapAS bool pseg := mm.pmas.LowerBoundSegment(ar.Start) - mf := mm.mfp.MemoryFile() for vseg := mm.vmas.LowerBoundSegment(ar.Start); vseg.Ok() && vseg.Start() < ar.End; vseg = vseg.NextSegment() { vma := vseg.ValuePtr() if vma.mlockMode != memmap.MLockNone { @@ -1081,20 +1074,8 @@ func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error { } } for pseg.Ok() && pseg.Start() < vsegAR.End { - pma := pseg.ValuePtr() - if pma.private && !mm.isPMACopyOnWriteLocked(vseg, pseg) { - psegAR := pseg.Range().Intersect(ar) - if vsegAR.IsSupersetOf(psegAR) && vma.mappable == nil { - if err := mf.Decommit(pseg.fileRangeOf(psegAR)); err == nil { - pseg = pseg.NextSegment() - continue - } - // If an error occurs, fall through to the general - // invalidation case below. - } - } pseg = mm.pmas.Isolate(pseg, vsegAR) - pma = pseg.ValuePtr() + pma := pseg.ValuePtr() if !didUnmapAS { // Unmap all of ar, not just pseg.Range(), to minimize host // syscalls. AddressSpace mappings must be removed before diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index aacd7ce70..17fb0a0d8 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -550,6 +550,12 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // Wait for the syscall-enter stop. sig := t.wait(stopped) + if sig == syscall.SIGSTOP { + // SIGSTOP was delivered to another thread in the same thread + // group, which initiated another group stop. Just ignore it. + continue + } + // Refresh all registers. if err := t.getRegs(regs); err != nil { panic(fmt.Sprintf("ptrace get regs failed: %v", err)) @@ -566,13 +572,11 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // Is it a system call? if sig == (syscallEvent | syscall.SIGTRAP) { + s.arm64SyscallWorkaround(t, regs) + // Ensure registers are sane. updateSyscallRegs(regs) return true - } else if sig == syscall.SIGSTOP { - // SIGSTOP was delivered to another thread in the same thread - // group, which initiated another group stop. Just ignore it. - continue } // Grab signal information. diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 020bbda79..04815282b 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -257,3 +257,6 @@ func probeSeccomp() bool { } } } + +func (s *subprocess) arm64SyscallWorkaround(t *thread, regs *arch.Registers) { +} diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index bd618fae8..416132967 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -21,6 +21,7 @@ import ( "strings" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -172,3 +173,38 @@ func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFActi func probeSeccomp() bool { return true } + +func (s *subprocess) arm64SyscallWorkaround(t *thread, regs *arch.Registers) { + // On ARM64, when ptrace stops on a system call, it uses the x7 + // register to indicate whether the stop has been signalled from + // syscall entry or syscall exit. This means that we can't get a value + // of this register and we can't change it. More details are in the + // comment for tracehook_report_syscall in arch/arm64/kernel/ptrace.c. + // + // This happens only if we stop on a system call, so let's queue a + // signal, resume a stub thread and catch it on a signal handling. + t.NotifyInterrupt() + for { + if _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + unix.PTRACE_SYSEMU, + uintptr(t.tid), 0, 0, 0, 0); errno != 0 { + panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) + } + + // Wait for the syscall-enter stop. + sig := t.wait(stopped) + if sig == syscall.SIGSTOP { + // SIGSTOP was delivered to another thread in the same thread + // group, which initiated another group stop. Just ignore it. + continue + } + if sig == (syscallEvent | syscall.SIGTRAP) { + t.dumpAndPanic(fmt.Sprintf("unexpected syscall event")) + } + break + } + if err := t.getRegs(regs); err != nil { + panic(fmt.Sprintf("ptrace get regs failed: %v", err)) + } +} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index b6ebe29d6..a8e6f172b 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -28,7 +28,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostfd", "//pkg/sentry/inet", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 5b868216d..17f59ba1f 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -377,10 +377,8 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 - case linux.IP_PKTINFO: - optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 9a2cac40b..f82c7c224 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -144,16 +143,6 @@ func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return int64(n), err } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) -} - type socketProviderVFS2 struct { family int } diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 70c561cce..2f913787b 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -15,7 +15,6 @@ package netfilter import ( - "bytes" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -220,18 +219,6 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) } - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), // A Protocol value of 0 indicates all protocols match. @@ -242,8 +229,11 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Src: tcpip.Address(iptip.Src[:]), SrcMask: tcpip.Address(iptip.SrcMask[:]), SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, + InputInterface: string(trimNullBytes(iptip.InputInterface[:])), + InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])), + InputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_IN != 0, + OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])), + OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])), OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, }, nil } @@ -254,12 +244,12 @@ func containsUnsupportedFields4(iptip linux.IPTIP) bool { // - Dst and DstMask // - Src and SrcMask // - The inverse destination IP check flag + // - InputInterface, InputInterfaceMask and its inverse. // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} + const flagMask = 0 // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags != 0 || + const inverseMask = linux.IPT_INV_DSTIP | linux.IPT_INV_SRCIP | + linux.IPT_INV_VIA_IN | linux.IPT_INV_VIA_OUT + return iptip.Flags&^flagMask != 0 || iptip.InverseFlags&^inverseMask != 0 } diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 5dbb604f0..263d9d3b5 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -15,7 +15,6 @@ package netfilter import ( - "bytes" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -223,18 +222,6 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) } - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), // In ip6tables a flag controls whether to check the protocol. @@ -245,8 +232,11 @@ func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { Src: tcpip.Address(iptip.Src[:]), SrcMask: tcpip.Address(iptip.SrcMask[:]), SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, + InputInterface: string(trimNullBytes(iptip.InputInterface[:])), + InputInterfaceMask: string(trimNullBytes(iptip.InputInterfaceMask[:])), + InputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_IN != 0, + OutputInterface: string(trimNullBytes(iptip.OutputInterface[:])), + OutputInterfaceMask: string(trimNullBytes(iptip.OutputInterfaceMask[:])), OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0, }, nil } @@ -257,14 +247,13 @@ func containsUnsupportedFields6(iptip linux.IP6TIP) bool { // - Dst and DstMask // - Src and SrcMask // - The inverse destination IP check flag + // - InputInterface, InputInterfaceMask and its inverse. // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} - flagMask := uint8(linux.IP6T_F_PROTO) + const flagMask = linux.IP6T_F_PROTO // Disable any supported inverse flags. - inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags&^flagMask != 0 || + const inverseMask = linux.IP6T_INV_DSTIP | linux.IP6T_INV_SRCIP | + linux.IP6T_INV_VIA_IN | linux.IP6T_INV_VIA_OUT + return iptip.Flags&^flagMask != 0 || iptip.InverseFlags&^inverseMask != 0 || iptip.TOS != 0 } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 26bd1abd4..7ae18b2a3 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,6 +17,7 @@ package netfilter import ( + "bytes" "errors" "fmt" @@ -393,3 +394,11 @@ func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkP rev.Revision = maxSupported return rev, nil } + +func trimNullBytes(b []byte) []byte { + n := bytes.IndexByte(b, 0) + if n == -1 { + n = len(b) + } + return b[:n] +} diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 69d13745e..176fa6116 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -112,7 +112,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 352c51390..2740697b3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index c88d8268d..466d5395d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, _, _ string) (bool, bool) { // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. switch pkt.NetworkProtocolNumber { diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 1f926aa91..9313e1167 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -22,7 +22,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index 461d524e5..842036764 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -149,13 +148,3 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 22abca120..915134b41 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -28,7 +28,6 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", - "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", @@ -42,7 +41,6 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 22e128b96..94f03af48 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -19,7 +19,7 @@ // be used to expose certain endpoints to the sentry while leaving others out, // for example, TCP endpoints and Unix-domain endpoints. // -// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside +// Lock ordering: netstack => mm: ioSequenceReadWriter copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. package netstack @@ -55,7 +55,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -194,7 +193,6 @@ var Metrics = tcpip.Stats{ RequestsReceivedUnknownTargetAddress: mustCreateMetric("/netstack/arp/requests_received_unknown_addr", "Number of ARP requests received with an unknown target address."), OutgoingRequestInterfaceHasNoLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_iface_has_no_addr", "Number of failed attempts to send an ARP request with an interface that has no network address."), OutgoingRequestBadLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_invalid_local_addr", "Number of failed attempts to send an ARP request with a provided local address that is invalid."), - OutgoingRequestNetworkUnreachableErrors: mustCreateMetric("/netstack/arp/outgoing_requests_network_unreachable", "Number of failed attempts to send an ARP request with a network unreachable error."), OutgoingRequestsDropped: mustCreateMetric("/netstack/arp/outgoing_requests_dropped", "Number of ARP requests which failed to write to a link-layer endpoint."), OutgoingRequestsSent: mustCreateMetric("/netstack/arp/outgoing_requests_sent", "Number of ARP requests sent."), RepliesReceived: mustCreateMetric("/netstack/arp/replies_received", "Number of ARP replies received."), @@ -253,11 +251,11 @@ var errStackType = syserr.New("expected but did not receive a netstack.Stack", l type commonEndpoint interface { // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress and // transport.Endpoint.GetLocalAddress. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress and // transport.Endpoint.GetRemoteAddress. - GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) + GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) // Readiness implements tcpip.Endpoint.Readiness and // transport.Endpoint.Readiness. @@ -265,19 +263,19 @@ type commonEndpoint interface { // SetSockOpt implements tcpip.Endpoint.SetSockOpt and // transport.Endpoint.SetSockOpt. - SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error + SetSockOpt(tcpip.SettableSocketOption) tcpip.Error // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. - SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. - GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error + GetSockOpt(tcpip.GettableSocketOption) tcpip.Error // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. - GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -285,7 +283,7 @@ type commonEndpoint interface { // LastError implements tcpip.Endpoint.LastError and // transport.Endpoint.LastError. - LastError() *tcpip.Error + LastError() tcpip.Error // SocketOptions implements tcpip.Endpoint.SocketOptions and // transport.Endpoint.SocketOptions. @@ -440,115 +438,58 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write return int64(res.Count), nil } -// ioSequencePayload implements tcpip.Payload. -// -// t copies user memory bytes on demand based on the requested size. -type ioSequencePayload struct { - ctx context.Context - src usermem.IOSequence -} - -// FullPayload implements tcpip.Payloader.FullPayload -func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { - return i.Payload(int(i.src.NumBytes())) -} - -// Payload implements tcpip.Payloader.Payload. -func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { - if max := int(i.src.NumBytes()); size > max { - size = max - } - v := buffer.NewView(size) - if _, err := i.src.CopyIn(i.ctx, v); err != nil { - // EOF can be returned only if src is a file and this means it - // is in a splice syscall and the error has to be ignored. - if err == io.EOF { - return v, nil - } - return nil, tcpip.ErrBadAddress - } - return v, nil -} - -// DropFirst drops the first n bytes from underlying src. -func (i *ioSequencePayload) DropFirst(n int) { - i.src = i.src.DropFirst(int(n)) -} - // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) - if err == tcpip.ErrWouldBlock { + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { return 0, syserror.ErrWouldBlock } if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } -// readerPayload implements tcpip.Payloader. -// -// It allocates a view and reads from a reader on-demand, based on available -// capacity in the endpoint. -type readerPayload struct { - ctx context.Context - r io.Reader - count int64 +var _ tcpip.Payloader = (*limitedPayloader)(nil) + +type limitedPayloader struct { + inner io.LimitedReader err error } -// FullPayload implements tcpip.Payloader.FullPayload. -func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { - return r.Payload(int(r.count)) +func (l *limitedPayloader) Read(p []byte) (int, error) { + n, err := l.inner.Read(p) + l.err = err + return n, err } -// Payload implements tcpip.Payloader.Payload. -func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { - if size > int(r.count) { - size = int(r.count) - } - v := buffer.NewView(size) - n, err := r.r.Read(v) - if n > 0 { - // We ignore the error here. It may re-occur on subsequent - // reads, but for now we can enqueue some amount of data. - r.count -= int64(n) - return v[:n], nil - } - if err == syserror.ErrWouldBlock { - return nil, tcpip.ErrWouldBlock - } else if err != nil { - r.err = err // Save for propation. - return nil, tcpip.ErrBadAddress - } - - // There is no data and no error. Return an error, which will propagate - // r.err, which will be nil. This is the desired result: (0, nil). - return nil, tcpip.ErrBadAddress +func (l *limitedPayloader) Len() int { + return int(l.inner.N) } // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { - f := &readerPayload{ctx: ctx, r: r, count: count} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + f := limitedPayloader{ + inner: io.LimitedReader{ + R: r, + N: count, + }, + } + n, err := s.Endpoint.Write(&f, tcpip.WriteOptions{ // Reads may be destructive but should be very fast, // so we can't release the lock while copying data. Atomic: true, }) - if err == tcpip.ErrWouldBlock { - return n, syserror.ErrWouldBlock - } else if err != nil { - return int64(n), f.err // Propagate error. + if _, ok := err.(*tcpip.ErrBadBuffer); ok { + return n, f.err } - - return int64(n), nil + return n, syserr.TranslateNetstackError(err).ToError() } // Readiness returns a mask of ready events for socket s. @@ -592,7 +533,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool if family == linux.AF_UNSPEC { err := s.Endpoint.Disconnect() - if err == tcpip.ErrNotSupported { + if _, ok := err.(*tcpip.ErrNotSupported); ok { return syserr.ErrAddressFamilyNotSupported } return syserr.TranslateNetstackError(err) @@ -614,15 +555,16 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool s.EventRegister(&e, waiter.EventOut) defer s.EventUnregister(&e) - if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting { + switch err := s.Endpoint.Connect(addr); err.(type) { + case *tcpip.ErrConnectStarted, *tcpip.ErrAlreadyConnecting: + case *tcpip.ErrNoPortAvailable: if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM { // TCP unlike UDP returns EADDRNOTAVAIL when it can't // find an available local ephemeral port. - if err == tcpip.ErrNoPortAvailable { - return syserr.ErrAddressNotAvailable - } + return syserr.ErrAddressNotAvailable } - + return syserr.TranslateNetstackError(err) + default: return syserr.TranslateNetstackError(err) } @@ -680,16 +622,16 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Issue the bind request to the endpoint. err := s.Endpoint.Bind(addr) - if err == tcpip.ErrNoPortAvailable { + if _, ok := err.(*tcpip.ErrNoPortAvailable); ok { // Bind always returns EADDRINUSE irrespective of if the specified port was // already bound or if an ephemeral port was requested but none were // available. // - // tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because + // *tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because // UDP connect returns EAGAIN on ephemeral port exhaustion. // // TCP connect returns EADDRNOTAVAIL on ephemeral port exhaustion. - err = tcpip.ErrPortInUse + err = &tcpip.ErrPortInUse{} } return syserr.TranslateNetstackError(err) @@ -712,7 +654,8 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAdd // Try to accept the connection again; if it fails, then wait until we // get a notification. for { - if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock { + ep, wq, err := s.Endpoint.Accept(peerAddr) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { return ep, wq, syserr.TranslateNetstackError(err) } @@ -731,7 +674,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { - if terr != tcpip.ErrWouldBlock || !blocking { + if _, ok := terr.(*tcpip.ErrWouldBlock); !ok || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } @@ -912,7 +855,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + size, err := ep.SocketOptions().GetSendBufferSize() if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1164,6 +1107,29 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, // TODO(b/64800844): Translate fields once they are added to // tcpip.TCPInfoOption. info := linux.TCPInfo{} + switch v.CcState { + case tcpip.RTORecovery: + info.CaState = linux.TCP_CA_Loss + case tcpip.FastRecovery, tcpip.SACKRecovery: + info.CaState = linux.TCP_CA_Recovery + case tcpip.Disorder: + info.CaState = linux.TCP_CA_Disorder + case tcpip.Open: + info.CaState = linux.TCP_CA_Open + } + info.RTO = uint32(v.RTO / time.Microsecond) + info.RTT = uint32(v.RTT / time.Microsecond) + info.RTTVar = uint32(v.RTTVar / time.Microsecond) + info.SndSsthresh = v.SndSsthresh + info.SndCwnd = v.SndCwnd + + // In netstack reorderSeen is updated only when RACK is enabled. + // We only track whether the reordering is seen, which is + // different than Linux where reorderSeen is not specific to + // RACK and is incremented when a reordering event is seen. + if v.ReorderSeen { + info.ReordSeen = 1 + } // Linux truncates the output binary to outLen. buf := t.CopyScratchBuffer(info.SizeBytes()) @@ -1681,8 +1647,16 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } + family, _, _ := s.Type() + // TODO(gvisor.dev/issue/5132): We currently do not support + // setting this option for unix sockets. + if family == linux.AF_UNIX { + return nil + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) + ep.SocketOptions().SetSendBufferSize(int64(v), true) + return nil case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1814,10 +1788,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam var v linux.Linger binary.Unmarshal(optVal[:linux.SizeOfLinger], usermem.ByteOrder, &v) - if v != (linux.Linger{}) { - socket.SetSockOptEmitUnimplementedEvent(t, name) - } - ep.SocketOptions().SetLinger(tcpip.LingerOption{ Enabled: v.OnOff != 0, Timeout: time.Second * time.Duration(v.Linger), @@ -2596,7 +2566,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq defer s.readMu.Unlock() res, err := s.Endpoint.Read(w, readOptions) - if err == tcpip.ErrBadBuffer && dst.NumBytes() == 0 { + if _, ok := err.(*tcpip.ErrBadBuffer); ok && dst.NumBytes() == 0 { err = nil } if err != nil { @@ -2840,45 +2810,48 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b EndOfRecord: flags&linux.MSG_EOR != 0, } - v := &ioSequencePayload{t, src} - n, err := s.Endpoint.Write(v, opts) - dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= v.src.NumBytes() || dontWait) { - // Complete write. - return int(n), nil - } - if err != nil && (err != tcpip.ErrWouldBlock || dontWait) { - return int(n), syserr.TranslateNetstackError(err) - } - - // We'll have to block. Register for notification and keep trying to - // send all the data. - e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventOut) - defer s.EventUnregister(&e) - - v.DropFirst(int(n)) - total := n + r := src.Reader(t) + var ( + total int64 + entry waiter.Entry + ch <-chan struct{} + ) for { - n, err = s.Endpoint.Write(v, opts) - v.DropFirst(int(n)) + n, err := s.Endpoint.Write(r, opts) total += n - - if err != nil && err != tcpip.ErrWouldBlock && total == 0 { - return 0, syserr.TranslateNetstackError(err) - } - - if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { - return int(total), nil - } - - if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { - return int(total), syserr.ErrTryAgain + if flags&linux.MSG_DONTWAIT != 0 { + return int(total), syserr.TranslateNetstackError(err) + } + block := true + switch err.(type) { + case nil: + block = total != src.NumBytes() + case *tcpip.ErrWouldBlock: + default: + block = false + } + if block { + if ch == nil { + // We'll have to block. Register for notification and keep trying to + // send all the data. + entry, ch = waiter.NewChannelEntry(nil) + s.EventRegister(&entry, waiter.EventOut) + defer s.EventUnregister(&entry) + } else { + // Don't wait immediately after registration in case more data + // became available between when we last checked and when we setup + // the notification. + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + return int(total), syserr.ErrTryAgain + } + // handleIOError will consume errors from t.Block if needed. + return int(total), syserr.FromError(err) + } } - // handleIOError will consume errors from t.Block if needed. - return int(total), syserr.FromError(err) + continue } + return int(total), syserr.TranslateNetstackError(err) } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 6f70b02fc..24922c400 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -129,20 +128,20 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return 0, syserror.EOPNOTSUPP } - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) - if err == tcpip.ErrWouldBlock { + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { return 0, syserror.ErrWouldBlock } if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } // Accept implements the linux syscall accept(2) for sockets backed by @@ -155,7 +154,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block } ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { - if terr != tcpip.ErrWouldBlock || !blocking { + if _, ok := terr.(*tcpip.ErrWouldBlock); !ok || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } @@ -262,13 +261,3 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index c847ff1c7..2515dda80 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -118,7 +118,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* // Create the endpoint. var ep tcpip.Endpoint - var e *tcpip.Error + var e tcpip.Error wq := &waiter.Queue{} if stype == linux.SOCK_RAW { ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 0af805246..ba1cc79e9 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -62,7 +62,7 @@ func (p *providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int // Create the endpoint. var ep tcpip.Endpoint - var e *tcpip.Error + var e tcpip.Error wq := &waiter.Queue{} if stype == linux.SOCK_RAW { ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9f7aca305..fc5b823b0 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -48,7 +48,7 @@ type ConnectingEndpoint interface { Type() linux.SockType // GetLocalAddress returns the bound path. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // Locker protects the following methods. While locked, only the holder of // the lock can change the return value of the protected methods. @@ -128,7 +128,7 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil, nil) return ep } @@ -173,7 +173,7 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil, nil) return ep } @@ -296,7 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } - ne.ops.InitHandler(ne) + ne.ops.InitHandler(ne, nil, nil) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.InitRefs() diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 0813ad87d..20fa8b874 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,7 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, nil, nil) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 099a56281..70227bbd2 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -169,32 +169,32 @@ type Endpoint interface { Type() linux.SockType // GetLocalAddress returns the address to which the endpoint is bound. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // GetRemoteAddress returns the address to which the endpoint is // connected. - GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) + GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) // SetSockOpt sets a socket option. - SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error + SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error // SetSockOptInt sets a socket option for simple cases when a value has // the int type. - SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error // GetSockOpt gets a socket option. - GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error + GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. - GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) // State returns the current state of the socket, as represented by Linux in // procfs. State() uint32 // LastError clears and returns the last error reported by the endpoint. - LastError() *tcpip.Error + LastError() tcpip.Error // SocketOptions returns the structure which contains all the socket // level options. @@ -580,7 +580,7 @@ type ConnectedEndpoint interface { Passcred() bool // GetLocalAddress implements Endpoint.GetLocalAddress. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // Send sends a single message. This method does not block. // @@ -640,7 +640,7 @@ type connectedEndpoint struct { Passcred() bool // GetLocalAddress implements Endpoint.GetLocalAddress. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // Type implements Endpoint.Type. Type() linux.SockType @@ -655,7 +655,7 @@ func (e *connectedEndpoint) Passcred() bool { } // GetLocalAddress implements ConnectedEndpoint.GetLocalAddress. -func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return e.endpoint.GetLocalAddress() } @@ -836,13 +836,12 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess } // SetSockOpt sets a socket option. -func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } -func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: case tcpip.ReceiveBufferSizeOption: default: log.Warningf("Unsupported socket option: %d", opt) @@ -850,19 +849,40 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +// IsUnixSocket implements tcpip.SocketOptionsHandler.IsUnixSocket. +func (e *baseEndpoint) IsUnixSocket() bool { + return true +} + +// GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize. +func (e *baseEndpoint) GetSendBufferSize() (int64, tcpip.Error) { + e.Lock() + defer e.Unlock() + + if !e.Connected() { + return -1, &tcpip.ErrNotConnected{} + } + + v := e.connected.SendMaxQueueSize() + if v < 0 { + return -1, &tcpip.ErrQueueSizeNotSupported{} + } + return v, nil +} + +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 e.Lock() if !e.Connected() { e.Unlock() - return -1, tcpip.ErrNotConnected + return -1, &tcpip.ErrNotConnected{} } v = int(e.receiver.RecvQueuedSize()) e.Unlock() if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported + return -1, &tcpip.ErrQueueSizeNotSupported{} } return v, nil @@ -870,25 +890,12 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.Lock() if !e.Connected() { e.Unlock() - return -1, tcpip.ErrNotConnected + return -1, &tcpip.ErrNotConnected{} } v := e.connected.SendQueuedSize() e.Unlock() if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported - } - return int(v), nil - - case tcpip.SendBufferSizeOption: - e.Lock() - if !e.Connected() { - e.Unlock() - return -1, tcpip.ErrNotConnected - } - v := e.connected.SendMaxQueueSize() - e.Unlock() - if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported + return -1, &tcpip.ErrQueueSizeNotSupported{} } return int(v), nil @@ -896,29 +903,29 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.Lock() if e.receiver == nil { e.Unlock() - return -1, tcpip.ErrNotConnected + return -1, &tcpip.ErrNotConnected{} } v := e.receiver.RecvMaxQueueSize() e.Unlock() if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported + return -1, &tcpip.ErrQueueSizeNotSupported{} } return int(v), nil default: log.Warningf("Unsupported socket option: %d", opt) - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } // LastError implements Endpoint.LastError. -func (*baseEndpoint) LastError() *tcpip.Error { +func (*baseEndpoint) LastError() tcpip.Error { return nil } @@ -958,7 +965,7 @@ func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error { } // GetLocalAddress returns the bound path. -func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.Lock() defer e.Unlock() return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil @@ -966,14 +973,14 @@ func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { // GetRemoteAddress returns the local address of the connected endpoint (if // available). -func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.Lock() c := e.connected e.Unlock() if c != nil { return c.GetLocalAddress() } - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Release implements BoundEndpoint.Release. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 6c4ec55b2..32e5d2304 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -496,6 +496,9 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b return int(n), syserr.FromError(err) } + // Only send SCM Rights once (see net/unix/af_unix.c:unix_stream_sendmsg). + w.Control.Rights = nil + // We'll have to block. Register for notification and keep trying to // send all the data. e, ch := waiter.NewChannelEntry(nil) diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 27f705bb2..a7d4d7f1f 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -331,16 +330,6 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) -} - // providerVFS2 is a unix domain socket provider for VFS2. type providerVFS2 struct{} diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index a72df62f6..62d1e8f8b 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -503,8 +503,8 @@ var ARM64 = &kernel.SyscallTable{ 72: syscalls.Supported("pselect", Pselect), 73: syscalls.Supported("ppoll", Ppoll), 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), - 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 76: syscalls.Supported("splice", Splice), 77: syscalls.Supported("tee", Tee), 78: syscalls.Supported("readlinkat", Readlinkat), 79: syscalls.Supported("fstatat", Fstatat), diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index c33571f43..a6253626e 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -1014,12 +1014,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } if cmd == linux.F_SETLK { // Non-blocking lock, provide a nil lock.Blocker. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.ReadLock, rng, nil) { return 0, nil, syserror.EAGAIN } } else { // Blocking lock, pass in the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, t) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.ReadLock, rng, t) { return 0, nil, syserror.EINTR } } @@ -1030,12 +1030,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } if cmd == linux.F_SETLK { // Non-blocking lock, provide a nil lock.Blocker. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.WriteLock, rng, nil) { return 0, nil, syserror.EAGAIN } } else { // Blocking lock, pass in the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, t) { + if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.WriteLock, rng, t) { return 0, nil, syserror.EINTR } } @@ -2167,24 +2167,24 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.LOCK_EX: if nonblocking { // Since we're nonblocking we pass a nil lock.Blocker implementation. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.WriteLock, rng, nil) { return 0, nil, syserror.EWOULDBLOCK } } else { // Because we're blocking we will pass the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, t) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.WriteLock, rng, t) { return 0, nil, syserror.EINTR } } case linux.LOCK_SH: if nonblocking { // Since we're nonblocking we pass a nil lock.Blocker implementation. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, nil) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.ReadLock, rng, nil) { return 0, nil, syserror.EWOULDBLOCK } } else { // Because we're blocking we will pass the task to satisfy the lock.Blocker interface. - if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, t) { + if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.ReadLock, rng, t) { return 0, nil, syserror.EINTR } } diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 7dd9ef857..1a31898e8 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -123,6 +123,15 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } defer file.DecRef(t) + if file.StatusFlags()&linux.O_PATH != 0 { + switch cmd { + case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC, linux.F_GETFD, linux.F_SETFD, linux.F_GETFL: + // allowed + default: + return 0, nil, syserror.EBADF + } + } + switch cmd { case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC: minfd := args[2].Int() @@ -205,8 +214,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } err := tmpfs.AddSeals(file, args[2].Uint()) return 0, nil, err - case linux.F_SETLK, linux.F_SETLKW: - return 0, nil, posixLock(t, args, file, cmd) + case linux.F_SETLK: + return 0, nil, posixLock(t, args, file, false /* blocking */) + case linux.F_SETLKW: + return 0, nil, posixLock(t, args, file, true /* blocking */) + case linux.F_GETLK: + return 0, nil, posixTestLock(t, args, file) case linux.F_GETSIG: a := file.AsyncHandler() if a == nil { @@ -292,7 +305,49 @@ func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, } } -func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error { +func posixTestLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription) error { + // Copy in the lock request. + flockAddr := args[2].Pointer() + var flock linux.Flock + if _, err := flock.CopyIn(t, flockAddr); err != nil { + return err + } + var typ lock.LockType + switch flock.Type { + case linux.F_RDLCK: + typ = lock.ReadLock + case linux.F_WRLCK: + typ = lock.WriteLock + default: + return syserror.EINVAL + } + r, err := file.ComputeLockRange(t, uint64(flock.Start), uint64(flock.Len), flock.Whence) + if err != nil { + return err + } + + newFlock, err := file.TestPOSIX(t, t.FDTable(), typ, r) + if err != nil { + return err + } + newFlock.PID = translatePID(t.PIDNamespace().Root(), t.PIDNamespace(), newFlock.PID) + if _, err = newFlock.CopyOut(t, flockAddr); err != nil { + return err + } + return nil +} + +// translatePID translates a pid from one namespace to another. Note that this +// may race with task termination/creation, in which case the original task +// corresponding to pid may no longer exist. This is used to implement the +// F_GETLK fcntl, which has the same potential race in Linux as well (i.e., +// there is no synchronization between retrieving the lock PID and translating +// it). See fs/locks.c:posix_lock_to_flock. +func translatePID(old, new *kernel.PIDNamespace, pid int32) int32 { + return int32(new.IDOfTask(old.TaskWithID(kernel.ThreadID(pid)))) +} + +func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, blocking bool) error { // Copy in the lock request. flockAddr := args[2].Pointer() var flock linux.Flock @@ -301,25 +356,30 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip } var blocker lock.Blocker - if cmd == linux.F_SETLKW { + if blocking { blocker = t } + r, err := file.ComputeLockRange(t, uint64(flock.Start), uint64(flock.Len), flock.Whence) + if err != nil { + return err + } + switch flock.Type { case linux.F_RDLCK: if !file.IsReadable() { return syserror.EBADF } - return file.LockPOSIX(t, t.FDTable(), lock.ReadLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker) + return file.LockPOSIX(t, t.FDTable(), int32(t.TGIDInRoot()), lock.ReadLock, r, blocker) case linux.F_WRLCK: if !file.IsWritable() { return syserror.EBADF } - return file.LockPOSIX(t, t.FDTable(), lock.WriteLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker) + return file.LockPOSIX(t, t.FDTable(), int32(t.TGIDInRoot()), lock.WriteLock, r, blocker) case linux.F_UNLCK: - return file.UnlockPOSIX(t, t.FDTable(), uint64(flock.Start), uint64(flock.Len), flock.Whence) + return file.UnlockPOSIX(t, t.FDTable(), r) default: return syserror.EINVAL @@ -344,6 +404,10 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } defer file.DecRef(t) + if file.StatusFlags()&linux.O_PATH != 0 { + return 0, nil, syserror.EBADF + } + // If the FD refers to a pipe or FIFO, return error. if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe { return 0, nil, syserror.ESPIPE diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 20c264fef..c7c3fed57 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -32,6 +32,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } defer file.DecRef(t) + if file.StatusFlags()&linux.O_PATH != 0 { + return 0, nil, syserror.EBADF + } + // Handle ioctls that apply to all FDs. switch args[1].Int() { case linux.FIONCLEX: diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go index b910b5a74..d1452a04d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/lock.go +++ b/pkg/sentry/syscalls/linux/vfs2/lock.go @@ -44,11 +44,11 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall switch operation { case linux.LOCK_EX: - if err := file.LockBSD(t, lock.WriteLock, blocker); err != nil { + if err := file.LockBSD(t, int32(t.TGIDInRoot()), lock.WriteLock, blocker); err != nil { return 0, nil, err } case linux.LOCK_SH: - if err := file.LockBSD(t, lock.ReadLock, blocker); err != nil { + if err := file.LockBSD(t, int32(t.TGIDInRoot()), lock.ReadLock, blocker); err != nil { return 0, nil, err } case linux.LOCK_UN: diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index b77b29dcc..c7417840f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -93,17 +93,11 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { n, err := file.Read(t, dst, opts) if err != syserror.ErrWouldBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } @@ -134,9 +128,6 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt } file.EventUnregister(&w) - if total > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return total, err } @@ -257,17 +248,11 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { n, err := file.PRead(t, dst, offset, opts) if err != syserror.ErrWouldBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } @@ -297,10 +282,6 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of } } file.EventUnregister(&w) - - if total > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return total, err } @@ -363,17 +344,11 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { n, err := file.Write(t, src, opts) if err != syserror.ErrWouldBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } return n, err } @@ -403,10 +378,6 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op } } file.EventUnregister(&w) - - if total > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } return total, err } @@ -527,17 +498,11 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { n, err := file.PWrite(t, src, offset, opts) if err != syserror.ErrWouldBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } return n, err } allowBlock, deadline, hasDeadline := blockPolicy(t, file) if !allowBlock { - if n > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return n, err } @@ -567,10 +532,6 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o } } file.EventUnregister(&w) - - if total > 0 { - file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - } return total, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 1ee37e5a8..903169dc2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -220,7 +220,6 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys length := args[3].Int64() file := t.GetFileVFS2(fd) - if file == nil { return 0, nil, syserror.EBADF } @@ -229,23 +228,18 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if !file.IsWritable() { return 0, nil, syserror.EBADF } - if mode != 0 { return 0, nil, syserror.ENOTSUP } - if offset < 0 || length <= 0 { return 0, nil, syserror.EINVAL } size := offset + length - if size < 0 { return 0, nil, syserror.EFBIG } - limit := limits.FromContext(t).Get(limits.FileSize).Cur - if uint64(size) >= limit { t.SendSignal(&arch.SignalInfo{ Signo: int32(linux.SIGXFSZ), @@ -254,12 +248,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EFBIG } - if err := file.Allocate(t, mode, uint64(offset), uint64(length)); err != nil { - return 0, nil, err - } - - file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - return 0, nil, nil + return 0, nil, file.Allocate(t, mode, uint64(offset), uint64(length)) } // Utime implements Linux syscall utime(2). diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 8bb763a47..19e175203 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -170,13 +170,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } } - if n != 0 { - // On Linux, inotify behavior is not very consistent with splice(2). We try - // our best to emulate Linux for very basic calls to splice, where for some - // reason, events are generated for output files, but not input files. - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - } - // We can only pass a single file to handleIOError, so pick inFile arbitrarily. // This is used only for debugging purposes. return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "splice", outFile) @@ -256,8 +249,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo } if n != 0 { - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - // If a partial write is completed, the error is dropped. Log it here. if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { log.Debugf("tee completed a partial write with error: %v", err) @@ -449,9 +440,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } if total != 0 { - inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { // If a partial write is completed, the error is dropped. Log it here. log.Debugf("sendfile completed a partial write with error: %v", err) diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index 6e9b599e2..1f8a5878c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -36,6 +36,10 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } defer file.DecRef(t) + if file.StatusFlags()&linux.O_PATH != 0 { + return 0, nil, syserror.EBADF + } + return 0, nil, file.SyncFS(t) } diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index a3868bf16..df4990854 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -83,6 +83,7 @@ go_library( "mount.go", "mount_namespace_refs.go", "mount_unsafe.go", + "opath.go", "options.go", "pathname.go", "permissions.go", diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 5321ac80a..f612a71b2 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -161,6 +161,13 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou // DecRef decrements fd's reference count. func (fd *FileDescription) DecRef(ctx context.Context) { fd.FileDescriptionRefs.DecRef(func() { + // Generate inotify events. + ev := uint32(linux.IN_CLOSE_NOWRITE) + if fd.IsWritable() { + ev = linux.IN_CLOSE_WRITE + } + fd.Dentry().InotifyWithParent(ctx, ev, 0, PathEvent) + // Unregister fd from all epoll instances. fd.epollMu.Lock() epolls := fd.epolls @@ -448,16 +455,19 @@ type FileDescriptionImpl interface { RemoveXattr(ctx context.Context, name string) error // LockBSD tries to acquire a BSD-style advisory file lock. - LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error + LockBSD(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, block lock.Blocker) error // UnlockBSD releases a BSD-style advisory file lock. UnlockBSD(ctx context.Context, uid lock.UniqueID) error // LockPOSIX tries to acquire a POSIX-style advisory file lock. - LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error + LockPOSIX(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, r lock.LockRange, block lock.Blocker) error // UnlockPOSIX releases a POSIX-style advisory file lock. - UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error + UnlockPOSIX(ctx context.Context, uid lock.UniqueID, ComputeLockRange lock.LockRange) error + + // TestPOSIX returns information about whether the specified lock can be held, in the style of the F_GETLK fcntl. + TestPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, r lock.LockRange) (linux.Flock, error) } // Dirent holds the information contained in struct linux_dirent64. @@ -556,7 +566,11 @@ func (fd *FileDescription) Allocate(ctx context.Context, mode, offset, length ui if !fd.IsWritable() { return syserror.EBADF } - return fd.impl.Allocate(ctx, mode, offset, length) + if err := fd.impl.Allocate(ctx, mode, offset, length); err != nil { + return err + } + fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent) + return nil } // Readiness implements waiter.Waitable.Readiness. @@ -592,6 +606,9 @@ func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of } start := fsmetric.StartReadWait() n, err := fd.impl.PRead(ctx, dst, offset, opts) + if n > 0 { + fd.Dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, PathEvent) + } fsmetric.Reads.Increment() fsmetric.FinishReadWait(fsmetric.ReadWait, start) return n, err @@ -604,6 +621,9 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt } start := fsmetric.StartReadWait() n, err := fd.impl.Read(ctx, dst, opts) + if n > 0 { + fd.Dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, PathEvent) + } fsmetric.Reads.Increment() fsmetric.FinishReadWait(fsmetric.ReadWait, start) return n, err @@ -619,7 +639,11 @@ func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, o if !fd.writable { return 0, syserror.EBADF } - return fd.impl.PWrite(ctx, src, offset, opts) + n, err := fd.impl.PWrite(ctx, src, offset, opts) + if n > 0 { + fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent) + } + return n, err } // Write is similar to PWrite, but does not specify an offset. @@ -627,7 +651,11 @@ func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, op if !fd.writable { return 0, syserror.EBADF } - return fd.impl.Write(ctx, src, opts) + n, err := fd.impl.Write(ctx, src, opts) + if n > 0 { + fd.Dentry().InotifyWithParent(ctx, linux.IN_MODIFY, 0, PathEvent) + } + return n, err } // IterDirents invokes cb on each entry in the directory represented by fd. If @@ -791,9 +819,9 @@ func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) e } // LockBSD tries to acquire a BSD-style advisory file lock. -func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, blocker lock.Blocker) error { +func (fd *FileDescription) LockBSD(ctx context.Context, ownerPID int32, lockType lock.LockType, blocker lock.Blocker) error { atomic.StoreUint32(&fd.usedLockBSD, 1) - return fd.impl.LockBSD(ctx, fd, lockType, blocker) + return fd.impl.LockBSD(ctx, fd, ownerPID, lockType, blocker) } // UnlockBSD releases a BSD-style advisory file lock. @@ -802,13 +830,45 @@ func (fd *FileDescription) UnlockBSD(ctx context.Context) error { } // LockPOSIX locks a POSIX-style file range lock. -func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, end uint64, whence int16, block lock.Blocker) error { - return fd.impl.LockPOSIX(ctx, uid, t, start, end, whence, block) +func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, r lock.LockRange, block lock.Blocker) error { + return fd.impl.LockPOSIX(ctx, uid, ownerPID, t, r, block) } // UnlockPOSIX unlocks a POSIX-style file range lock. -func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, end uint64, whence int16) error { - return fd.impl.UnlockPOSIX(ctx, uid, start, end, whence) +func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, r lock.LockRange) error { + return fd.impl.UnlockPOSIX(ctx, uid, r) +} + +// TestPOSIX returns information about whether the specified lock can be held. +func (fd *FileDescription) TestPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, r lock.LockRange) (linux.Flock, error) { + return fd.impl.TestPOSIX(ctx, uid, t, r) +} + +// ComputeLockRange computes the range of a file lock based on the given values. +func (fd *FileDescription) ComputeLockRange(ctx context.Context, start uint64, length uint64, whence int16) (lock.LockRange, error) { + var off int64 + switch whence { + case linux.SEEK_SET: + off = 0 + case linux.SEEK_CUR: + // Note that Linux does not hold any mutexes while retrieving the file + // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk. + curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR) + if err != nil { + return lock.LockRange{}, err + } + off = curOff + case linux.SEEK_END: + stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE}) + if err != nil { + return lock.LockRange{}, err + } + off = int64(stat.Size) + default: + return lock.LockRange{}, syserror.EINVAL + } + + return lock.ComputeRange(int64(start), int64(length), off) } // A FileAsync sends signals to its owner when w is ready for IO. This is only diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 48ca9de44..eb7d2fd3b 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -419,8 +419,8 @@ func (fd *LockFD) Locks() *FileLocks { } // LockBSD implements vfs.FileDescriptionImpl.LockBSD. -func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { - return fd.locks.LockBSD(uid, t, block) +func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { + return fd.locks.LockBSD(ctx, uid, ownerPID, t, block) } // UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. @@ -429,6 +429,21 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { return nil } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { + return fd.locks.LockPOSIX(ctx, uid, ownerPID, t, r, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + return fd.locks.UnlockPOSIX(ctx, uid, r) +} + +// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return fd.locks.TestPOSIX(ctx, uid, t, r) +} + // NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface // returning ENOLCK. // @@ -436,7 +451,7 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { type NoLockFD struct{} // LockBSD implements vfs.FileDescriptionImpl.LockBSD. -func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { +func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return syserror.ENOLCK } @@ -446,11 +461,16 @@ func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { +func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { return syserror.ENOLCK } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { +func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { return syserror.ENOLCK } + +// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +func (NoLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return linux.Flock{}, syserror.ENOLCK +} diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 1ff202f2a..cbe4d8c2d 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -39,8 +39,8 @@ type FileLocks struct { } // LockBSD tries to acquire a BSD-style lock on the entire file. -func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { - if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) { +func (fl *FileLocks) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerID int32, t fslock.LockType, block fslock.Blocker) error { + if fl.bsd.LockRegion(uid, ownerID, t, fslock.LockRange{0, fslock.LockEOF}, block) { return nil } @@ -61,12 +61,8 @@ func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) { } // LockPOSIX tries to acquire a POSIX-style lock on a file region. -func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - rng, err := computeRange(ctx, fd, start, length, whence) - if err != nil { - return err - } - if fl.posix.LockRegion(uid, t, rng, block) { +func (fl *FileLocks) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { + if fl.posix.LockRegion(uid, ownerPID, t, r, block) { return nil } @@ -82,37 +78,12 @@ func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fsl // // This operation is always successful, even if there did not exist a lock on // the requested region held by uid in the first place. -func (fl *FileLocks) UnlockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, start, length uint64, whence int16) error { - rng, err := computeRange(ctx, fd, start, length, whence) - if err != nil { - return err - } - fl.posix.UnlockRegion(uid, rng) +func (fl *FileLocks) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + fl.posix.UnlockRegion(uid, r) return nil } -func computeRange(ctx context.Context, fd *FileDescription, start uint64, length uint64, whence int16) (fslock.LockRange, error) { - var off int64 - switch whence { - case linux.SEEK_SET: - off = 0 - case linux.SEEK_CUR: - // Note that Linux does not hold any mutexes while retrieving the file - // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk. - curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR) - if err != nil { - return fslock.LockRange{}, err - } - off = curOff - case linux.SEEK_END: - stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE}) - if err != nil { - return fslock.LockRange{}, err - } - off = int64(stat.Size) - default: - return fslock.LockRange{}, syserror.EINVAL - } - - return fslock.ComputeRange(int64(start), int64(length), off) +// TestPOSIX returns information about whether the specified lock can be held, in the style of the F_GETLK fcntl. +func (fl *FileLocks) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return fl.posix.TestRegion(ctx, uid, t, r), nil } diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go new file mode 100644 index 000000000..39fbac987 --- /dev/null +++ b/pkg/sentry/vfs/opath.go @@ -0,0 +1,139 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// opathFD implements vfs.FileDescriptionImpl for a file description opened with O_PATH. +// +// +stateify savable +type opathFD struct { + vfsfd FileDescription + FileDescriptionDefaultImpl + NoLockFD +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *opathFD) Release(context.Context) { + // noop +} + +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *opathFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.EBADF +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *opathFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + return 0, syserror.EBADF +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *opathFD) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + return 0, syserror.EBADF +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *opathFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + return 0, syserror.EBADF +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *opathFD) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + return 0, syserror.EBADF +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *opathFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return 0, syserror.EBADF +} + +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +func (fd *opathFD) IterDirents(ctx context.Context, cb IterDirentsCallback) error { + return syserror.EBADF +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *opathFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return 0, syserror.EBADF +} + +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *opathFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return syserror.EBADF +} + +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *opathFD) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return nil, syserror.EBADF +} + +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *opathFD) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) { + return "", syserror.EBADF +} + +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *opathFD) SetXattr(ctx context.Context, opts SetXattrOptions) error { + return syserror.EBADF +} + +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *opathFD) RemoveXattr(ctx context.Context, name string) error { + return syserror.EBADF +} + +// Sync implements vfs.FileDescriptionImpl.Sync. +func (fd *opathFD) Sync(ctx context.Context) error { + return syserror.EBADF +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *opathFD) SetStat(ctx context.Context, opts SetStatOptions) error { + return syserror.EBADF +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + vfsObj := fd.vfsfd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vfsfd.vd, + Start: fd.vfsfd.vd, + }) + stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts) + vfsObj.putResolvingPath(ctx, rp) + return stat, err +} + +// StatFS returns metadata for the filesystem containing the file represented +// by fd. +func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) { + vfsObj := fd.vfsfd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vfsfd.vd, + Start: fd.vfsfd.vd, + }) + statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp) + vfsObj.putResolvingPath(ctx, rp) + return statfs, err +} diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index bc79e5ecc..c9907843c 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -129,7 +129,7 @@ type OpenOptions struct { // // FilesystemImpls are responsible for implementing the following flags: // O_RDONLY, O_WRONLY, O_RDWR, O_APPEND, O_CREAT, O_DIRECT, O_DSYNC, - // O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_PATH, O_SYNC, O_TMPFILE, and + // O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_SYNC, O_TMPFILE, and // O_TRUNC. VFS is responsible for handling O_DIRECTORY, O_LARGEFILE, and // O_NOFOLLOW. VFS users are responsible for handling O_CLOEXEC, since file // descriptors are mostly outside the scope of VFS. diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 6fd1bb0b2..0aff2dd92 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -425,6 +425,18 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential rp.mustBeDir = true rp.mustBeDirOrig = true } + if opts.Flags&linux.O_PATH != 0 { + vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{}) + if err != nil { + return nil, err + } + fd := &opathFD{} + if err := fd.vfsfd.Init(fd, opts.Flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{}); err != nil { + return nil, err + } + vd.DecRef(ctx) + return &fd.vfsfd, err + } for { fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) if err == nil { diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 28e62abbb..2e2395807 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -67,6 +67,7 @@ go_library( ], marshal = False, stateify = False, + visibility = ["//:sandbox"], deps = [ "//pkg/goid", ], diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index cb8981633..a6a91e064 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -23,95 +23,114 @@ import ( // Mapping for tcpip.Error types. var ( - ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL) - ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV) - ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV) - ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT) - ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST) - ErrDuplicateAddress = New(tcpip.ErrDuplicateAddress.String(), linux.EEXIST) - ErrBadLinkEndpoint = New(tcpip.ErrBadLinkEndpoint.String(), linux.EINVAL) - ErrAlreadyBound = New(tcpip.ErrAlreadyBound.String(), linux.EINVAL) - ErrInvalidEndpointState = New(tcpip.ErrInvalidEndpointState.String(), linux.EINVAL) - ErrAlreadyConnecting = New(tcpip.ErrAlreadyConnecting.String(), linux.EALREADY) - ErrNoPortAvailable = New(tcpip.ErrNoPortAvailable.String(), linux.EAGAIN) - ErrPortInUse = New(tcpip.ErrPortInUse.String(), linux.EADDRINUSE) - ErrBadLocalAddress = New(tcpip.ErrBadLocalAddress.String(), linux.EADDRNOTAVAIL) - ErrClosedForSend = New(tcpip.ErrClosedForSend.String(), linux.EPIPE) - ErrClosedForReceive = New(tcpip.ErrClosedForReceive.String(), nil) - ErrTimeout = New(tcpip.ErrTimeout.String(), linux.ETIMEDOUT) - ErrAborted = New(tcpip.ErrAborted.String(), linux.EPIPE) - ErrConnectStarted = New(tcpip.ErrConnectStarted.String(), linux.EINPROGRESS) - ErrDestinationRequired = New(tcpip.ErrDestinationRequired.String(), linux.EDESTADDRREQ) - ErrNotSupported = New(tcpip.ErrNotSupported.String(), linux.EOPNOTSUPP) - ErrQueueSizeNotSupported = New(tcpip.ErrQueueSizeNotSupported.String(), linux.ENOTTY) - ErrNoSuchFile = New(tcpip.ErrNoSuchFile.String(), linux.ENOENT) - ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL) - ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES) - ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM) - ErrBadBuffer = New(tcpip.ErrBadBuffer.String(), linux.EFAULT) + ErrUnknownProtocol = New((&tcpip.ErrUnknownProtocol{}).String(), linux.EINVAL) + ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), linux.ENODEV) + ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), linux.ENODEV) + ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), linux.ENOPROTOOPT) + ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), linux.EEXIST) + ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), linux.EEXIST) + ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), linux.EINVAL) + ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), linux.EINVAL) + ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), linux.EALREADY) + ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), linux.EAGAIN) + ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE) + ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL) + ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE) + ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), nil) + ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT) + ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE) + ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS) + ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), linux.EDESTADDRREQ) + ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), linux.EOPNOTSUPP) + ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), linux.ENOTTY) + ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), linux.ENOENT) + ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), linux.EINVAL) + ErrBroadcastDisabled = New((&tcpip.ErrBroadcastDisabled{}).String(), linux.EACCES) + ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM) + ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT) ) -var netstackErrorTranslations map[string]*Error - -func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) { - key := tcpipErr.String() - if _, ok := netstackErrorTranslations[key]; ok { - panic(fmt.Sprintf("duplicate error key: %s", key)) - } - netstackErrorTranslations[key] = netstackErr -} - -func init() { - netstackErrorTranslations = make(map[string]*Error) - addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol) - addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID) - addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice) - addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption) - addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID) - addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress) - addErrMapping(tcpip.ErrNoRoute, ErrNoRoute) - addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint) - addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound) - addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState) - addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting) - addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected) - addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable) - addErrMapping(tcpip.ErrPortInUse, ErrPortInUse) - addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress) - addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend) - addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive) - addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock) - addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused) - addErrMapping(tcpip.ErrTimeout, ErrTimeout) - addErrMapping(tcpip.ErrAborted, ErrAborted) - addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted) - addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired) - addErrMapping(tcpip.ErrNotSupported, ErrNotSupported) - addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported) - addErrMapping(tcpip.ErrNotConnected, ErrNotConnected) - addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset) - addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted) - addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile) - addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue) - addErrMapping(tcpip.ErrBadAddress, ErrBadAddress) - addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable) - addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong) - addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace) - addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled) - addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet) - addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported) - addErrMapping(tcpip.ErrBadBuffer, ErrBadBuffer) -} - // TranslateNetstackError converts an error from the tcpip package to a sentry // internal error. -func TranslateNetstackError(err *tcpip.Error) *Error { - if err == nil { +func TranslateNetstackError(err tcpip.Error) *Error { + switch err.(type) { + case nil: return nil + case *tcpip.ErrUnknownProtocol: + return ErrUnknownProtocol + case *tcpip.ErrUnknownNICID: + return ErrUnknownNICID + case *tcpip.ErrUnknownDevice: + return ErrUnknownDevice + case *tcpip.ErrUnknownProtocolOption: + return ErrUnknownProtocolOption + case *tcpip.ErrDuplicateNICID: + return ErrDuplicateNICID + case *tcpip.ErrDuplicateAddress: + return ErrDuplicateAddress + case *tcpip.ErrNoRoute: + return ErrNoRoute + case *tcpip.ErrAlreadyBound: + return ErrAlreadyBound + case *tcpip.ErrInvalidEndpointState: + return ErrInvalidEndpointState + case *tcpip.ErrAlreadyConnecting: + return ErrAlreadyConnecting + case *tcpip.ErrAlreadyConnected: + return ErrAlreadyConnected + case *tcpip.ErrNoPortAvailable: + return ErrNoPortAvailable + case *tcpip.ErrPortInUse: + return ErrPortInUse + case *tcpip.ErrBadLocalAddress: + return ErrBadLocalAddress + case *tcpip.ErrClosedForSend: + return ErrClosedForSend + case *tcpip.ErrClosedForReceive: + return ErrClosedForReceive + case *tcpip.ErrWouldBlock: + return ErrWouldBlock + case *tcpip.ErrConnectionRefused: + return ErrConnectionRefused + case *tcpip.ErrTimeout: + return ErrTimeout + case *tcpip.ErrAborted: + return ErrAborted + case *tcpip.ErrConnectStarted: + return ErrConnectStarted + case *tcpip.ErrDestinationRequired: + return ErrDestinationRequired + case *tcpip.ErrNotSupported: + return ErrNotSupported + case *tcpip.ErrQueueSizeNotSupported: + return ErrQueueSizeNotSupported + case *tcpip.ErrNotConnected: + return ErrNotConnected + case *tcpip.ErrConnectionReset: + return ErrConnectionReset + case *tcpip.ErrConnectionAborted: + return ErrConnectionAborted + case *tcpip.ErrNoSuchFile: + return ErrNoSuchFile + case *tcpip.ErrInvalidOptionValue: + return ErrInvalidOptionValue + case *tcpip.ErrBadAddress: + return ErrBadAddress + case *tcpip.ErrNetworkUnreachable: + return ErrNetworkUnreachable + case *tcpip.ErrMessageTooLong: + return ErrMessageTooLong + case *tcpip.ErrNoBufferSpace: + return ErrNoBufferSpace + case *tcpip.ErrBroadcastDisabled: + return ErrBroadcastDisabled + case *tcpip.ErrNotPermitted: + return ErrNotPermittedNet + case *tcpip.ErrAddressFamilyNotSupported: + return ErrAddressFamilyNotSupported + case *tcpip.ErrBadBuffer: + return ErrBadBuffer + default: + panic(fmt.Sprintf("unknown error %T", err)) } - se, ok := netstackErrorTranslations[err.String()] - if !ok { - panic("Unknown error: " + err.String()) - } - return se } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index e7924e5c2..f979d22f0 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "tcpip", srcs = [ + "errors.go", "sock_err_list.go", "socketops.go", "tcpip.go", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index fdeec12d3..c188aaa18 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -16,6 +16,7 @@ package gonet import ( + "bytes" "context" "errors" "io" @@ -247,7 +248,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { func (l *TCPListener) Accept() (net.Conn, error) { n, wq, err := l.ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) l.wq.EventRegister(&waitEntry, waiter.EventIn) @@ -256,7 +257,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { for { n, wq, err = l.ep.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } @@ -297,14 +298,14 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} res, err := ep.Read(&w, opts) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { res, err = ep.Read(&w, opts) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } select { @@ -315,7 +316,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s } } - if err == tcpip.ErrClosedForReceive { + if _, ok := err.(*tcpip.ErrClosedForReceive); ok { return 0, io.EOF } @@ -354,10 +355,8 @@ func (c *TCPConn) Write(b []byte) (int, error) { default: } - v := buffer.NewViewFromBytes(b) - // We must handle two soft failure conditions simultaneously: - // 1. Write may write nothing and return tcpip.ErrWouldBlock. + // 1. Write may write nothing and return *tcpip.ErrWouldBlock. // If this happens, we need to register for notifications if we have // not already and wait to try again. // 2. Write may write fewer than the full number of bytes and return @@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) { // There is no guarantee that all of the condition #1s will occur before // all of the condition #2s or visa-versa. var ( - err *tcpip.Error - nbytes int - reg bool - notifyCh chan struct{} + r bytes.Reader + nbytes int + entry waiter.Entry + ch <-chan struct{} ) - for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { - if err == tcpip.ErrWouldBlock { - if !reg { - // Only register once. - reg = true - - // Create wait queue entry that notifies a channel. - var waitEntry waiter.Entry - waitEntry, notifyCh = waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) + for nbytes != len(b) { + r.Reset(b[nbytes:]) + n, err := c.ep.Write(&r, tcpip.WriteOptions{}) + nbytes += int(n) + switch err.(type) { + case nil: + case *tcpip.ErrWouldBlock: + if ch == nil { + entry, ch = waiter.NewChannelEntry(nil) + + c.wq.EventRegister(&entry, waiter.EventOut) + defer c.wq.EventUnregister(&entry) } else { // Don't wait immediately after registration in case more data // became available between when we last checked and when we setup @@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) { select { case <-deadline: return nbytes, c.newOpError("write", &timeoutError{}) - case <-notifyCh: + case <-ch: + continue } } + default: + return nbytes, c.newOpError("write", errors.New(err.String())) } - - var n int64 - n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } - - if err == nil { - return nbytes, nil } - - return nbytes, c.newOpError("write", errors.New(err.String())) + return nbytes, nil } // Close implements net.Conn.Close. @@ -502,7 +495,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, } err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); ok { select { case <-ctx.Done(): ep.Close() @@ -644,17 +637,19 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // If we're being called by Write, there is no addr - wopts := tcpip.WriteOptions{} + writeOptions := tcpip.WriteOptions{} if addr != nil { ua := addr.(*net.UDPAddr) - wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} + writeOptions.To = &tcpip.FullAddress{ + Addr: tcpip.Address(ua.IP), + Port: uint16(ua.Port), + } } - v := buffer.NewView(len(b)) - copy(v, b) - - n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) - if err == tcpip.ErrWouldBlock { + var r bytes.Reader + r.Reset(b) + n, err := c.ep.Write(&r, writeOptions) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.wq.EventRegister(&waitEntry, waiter.EventOut) @@ -666,8 +661,8 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { case <-notifyCh: } - n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) - if err != tcpip.ErrWouldBlock { + n, err = c.ep.Write(&r, writeOptions) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index b196324c7..2b3ea4bdf 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -58,7 +58,7 @@ func TestTimeouts(t *testing.T) { } } -func newLoopbackStack() (*stack.Stack, *tcpip.Error) { +func newLoopbackStack() (*stack.Stack, tcpip.Error) { // Create the stack and add a NIC. s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, @@ -94,7 +94,7 @@ type testConnection struct { ep tcpip.Endpoint } -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { +func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { @@ -105,7 +105,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er wq.EventRegister(&entry, waiter.EventOut) err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); ok { <-ch err = ep.LastError() } @@ -660,11 +660,13 @@ func TestTCPDialError(t *testing.T) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - _, err := DialTCP(s, addr, ipv4.ProtocolNumber) - got, ok := err.(*net.OpError) - want := tcpip.ErrNoRoute - if !ok || got.Err.Error() != want.String() { - t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute) + switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) { + case *net.OpError: + if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() { + t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{}) + } + default: + t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{}) } } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 0ac2000ca..07b4393a4 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -234,7 +234,7 @@ func IPv4RouterAlert() NetworkChecker { for { opt, done, err := iterator.Next() if err != nil { - t.Fatalf("error acquiring next IPv4 option %s", err) + t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer) } if done { break diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go new file mode 100644 index 000000000..af46da1d2 --- /dev/null +++ b/pkg/tcpip/errors.go @@ -0,0 +1,538 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "fmt" +) + +// Error represents an error in the netstack error space. +// +// The error interface is intentionally omitted to avoid loss of type +// information that would occur if these errors were passed as error. +type Error interface { + isError() + + // IgnoreStats indicates whether this error should be included in failure + // counts in tcpip.Stats structs. + IgnoreStats() bool + + fmt.Stringer +} + +// ErrAborted indicates the operation was aborted. +// +// +stateify savable +type ErrAborted struct{} + +func (*ErrAborted) isError() {} + +// IgnoreStats implements Error. +func (*ErrAborted) IgnoreStats() bool { + return false +} +func (*ErrAborted) String() string { + return "operation aborted" +} + +// ErrAddressFamilyNotSupported indicates the operation does not support the +// given address family. +// +// +stateify savable +type ErrAddressFamilyNotSupported struct{} + +func (*ErrAddressFamilyNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrAddressFamilyNotSupported) IgnoreStats() bool { + return false +} +func (*ErrAddressFamilyNotSupported) String() string { + return "address family not supported by protocol" +} + +// ErrAlreadyBound indicates the endpoint is already bound. +// +// +stateify savable +type ErrAlreadyBound struct{} + +func (*ErrAlreadyBound) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyBound) IgnoreStats() bool { + return true +} +func (*ErrAlreadyBound) String() string { return "endpoint already bound" } + +// ErrAlreadyConnected indicates the endpoint is already connected. +// +// +stateify savable +type ErrAlreadyConnected struct{} + +func (*ErrAlreadyConnected) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyConnected) IgnoreStats() bool { + return true +} +func (*ErrAlreadyConnected) String() string { return "endpoint is already connected" } + +// ErrAlreadyConnecting indicates the endpoint is already connecting. +// +// +stateify savable +type ErrAlreadyConnecting struct{} + +func (*ErrAlreadyConnecting) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyConnecting) IgnoreStats() bool { + return true +} +func (*ErrAlreadyConnecting) String() string { return "endpoint is already connecting" } + +// ErrBadAddress indicates a bad address was provided. +// +// +stateify savable +type ErrBadAddress struct{} + +func (*ErrBadAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadAddress) IgnoreStats() bool { + return false +} +func (*ErrBadAddress) String() string { return "bad address" } + +// ErrBadBuffer indicates a bad buffer was provided. +// +// +stateify savable +type ErrBadBuffer struct{} + +func (*ErrBadBuffer) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadBuffer) IgnoreStats() bool { + return false +} +func (*ErrBadBuffer) String() string { return "bad buffer" } + +// ErrBadLocalAddress indicates a bad local address was provided. +// +// +stateify savable +type ErrBadLocalAddress struct{} + +func (*ErrBadLocalAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadLocalAddress) IgnoreStats() bool { + return false +} +func (*ErrBadLocalAddress) String() string { return "bad local address" } + +// ErrBroadcastDisabled indicates broadcast is not enabled on the endpoint. +// +// +stateify savable +type ErrBroadcastDisabled struct{} + +func (*ErrBroadcastDisabled) isError() {} + +// IgnoreStats implements Error. +func (*ErrBroadcastDisabled) IgnoreStats() bool { + return false +} +func (*ErrBroadcastDisabled) String() string { return "broadcast socket option disabled" } + +// ErrClosedForReceive indicates the endpoint is closed for incoming data. +// +// +stateify savable +type ErrClosedForReceive struct{} + +func (*ErrClosedForReceive) isError() {} + +// IgnoreStats implements Error. +func (*ErrClosedForReceive) IgnoreStats() bool { + return false +} +func (*ErrClosedForReceive) String() string { return "endpoint is closed for receive" } + +// ErrClosedForSend indicates the endpoint is closed for outgoing data. +// +// +stateify savable +type ErrClosedForSend struct{} + +func (*ErrClosedForSend) isError() {} + +// IgnoreStats implements Error. +func (*ErrClosedForSend) IgnoreStats() bool { + return false +} +func (*ErrClosedForSend) String() string { return "endpoint is closed for send" } + +// ErrConnectStarted indicates the endpoint is connecting asynchronously. +// +// +stateify savable +type ErrConnectStarted struct{} + +func (*ErrConnectStarted) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectStarted) IgnoreStats() bool { + return true +} +func (*ErrConnectStarted) String() string { return "connection attempt started" } + +// ErrConnectionAborted indicates the connection was aborted. +// +// +stateify savable +type ErrConnectionAborted struct{} + +func (*ErrConnectionAborted) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionAborted) IgnoreStats() bool { + return false +} +func (*ErrConnectionAborted) String() string { return "connection aborted" } + +// ErrConnectionRefused indicates the connection was refused. +// +// +stateify savable +type ErrConnectionRefused struct{} + +func (*ErrConnectionRefused) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionRefused) IgnoreStats() bool { + return false +} +func (*ErrConnectionRefused) String() string { return "connection was refused" } + +// ErrConnectionReset indicates the connection was reset. +// +// +stateify savable +type ErrConnectionReset struct{} + +func (*ErrConnectionReset) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionReset) IgnoreStats() bool { + return false +} +func (*ErrConnectionReset) String() string { return "connection reset by peer" } + +// ErrDestinationRequired indicates the operation requires a destination +// address, and one was not provided. +// +// +stateify savable +type ErrDestinationRequired struct{} + +func (*ErrDestinationRequired) isError() {} + +// IgnoreStats implements Error. +func (*ErrDestinationRequired) IgnoreStats() bool { + return false +} +func (*ErrDestinationRequired) String() string { return "destination address is required" } + +// ErrDuplicateAddress indicates the operation encountered a duplicate address. +// +// +stateify savable +type ErrDuplicateAddress struct{} + +func (*ErrDuplicateAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrDuplicateAddress) IgnoreStats() bool { + return false +} +func (*ErrDuplicateAddress) String() string { return "duplicate address" } + +// ErrDuplicateNICID indicates the operation encountered a duplicate NIC ID. +// +// +stateify savable +type ErrDuplicateNICID struct{} + +func (*ErrDuplicateNICID) isError() {} + +// IgnoreStats implements Error. +func (*ErrDuplicateNICID) IgnoreStats() bool { + return false +} +func (*ErrDuplicateNICID) String() string { return "duplicate nic id" } + +// ErrInvalidEndpointState indicates the endpoint is in an invalid state. +// +// +stateify savable +type ErrInvalidEndpointState struct{} + +func (*ErrInvalidEndpointState) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidEndpointState) IgnoreStats() bool { + return false +} +func (*ErrInvalidEndpointState) String() string { return "endpoint is in invalid state" } + +// ErrInvalidOptionValue indicates an invalid option value was provided. +// +// +stateify savable +type ErrInvalidOptionValue struct{} + +func (*ErrInvalidOptionValue) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidOptionValue) IgnoreStats() bool { + return false +} +func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" } + +// ErrMalformedHeader indicates the operation encountered a malformed header. +// +// +stateify savable +type ErrMalformedHeader struct{} + +func (*ErrMalformedHeader) isError() {} + +// IgnoreStats implements Error. +func (*ErrMalformedHeader) IgnoreStats() bool { + return false +} +func (*ErrMalformedHeader) String() string { return "header is malformed" } + +// ErrMessageTooLong indicates the operation encountered a message whose length +// exceeds the maximum permitted. +// +// +stateify savable +type ErrMessageTooLong struct{} + +func (*ErrMessageTooLong) isError() {} + +// IgnoreStats implements Error. +func (*ErrMessageTooLong) IgnoreStats() bool { + return false +} +func (*ErrMessageTooLong) String() string { return "message too long" } + +// ErrNetworkUnreachable indicates the operation is not able to reach the +// destination network. +// +// +stateify savable +type ErrNetworkUnreachable struct{} + +func (*ErrNetworkUnreachable) isError() {} + +// IgnoreStats implements Error. +func (*ErrNetworkUnreachable) IgnoreStats() bool { + return false +} +func (*ErrNetworkUnreachable) String() string { return "network is unreachable" } + +// ErrNoBufferSpace indicates no buffer space is available. +// +// +stateify savable +type ErrNoBufferSpace struct{} + +func (*ErrNoBufferSpace) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoBufferSpace) IgnoreStats() bool { + return false +} +func (*ErrNoBufferSpace) String() string { return "no buffer space available" } + +// ErrNoPortAvailable indicates no port could be allocated for the operation. +// +// +stateify savable +type ErrNoPortAvailable struct{} + +func (*ErrNoPortAvailable) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoPortAvailable) IgnoreStats() bool { + return false +} +func (*ErrNoPortAvailable) String() string { return "no ports are available" } + +// ErrNoRoute indicates the operation is not able to find a route to the +// destination. +// +// +stateify savable +type ErrNoRoute struct{} + +func (*ErrNoRoute) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoRoute) IgnoreStats() bool { + return false +} +func (*ErrNoRoute) String() string { return "no route" } + +// ErrNoSuchFile is used to indicate that ENOENT should be returned the to +// calling application. +// +// +stateify savable +type ErrNoSuchFile struct{} + +func (*ErrNoSuchFile) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoSuchFile) IgnoreStats() bool { + return false +} +func (*ErrNoSuchFile) String() string { return "no such file" } + +// ErrNotConnected indicates the endpoint is not connected. +// +// +stateify savable +type ErrNotConnected struct{} + +func (*ErrNotConnected) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotConnected) IgnoreStats() bool { + return false +} +func (*ErrNotConnected) String() string { return "endpoint not connected" } + +// ErrNotPermitted indicates the operation is not permitted. +// +// +stateify savable +type ErrNotPermitted struct{} + +func (*ErrNotPermitted) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotPermitted) IgnoreStats() bool { + return false +} +func (*ErrNotPermitted) String() string { return "operation not permitted" } + +// ErrNotSupported indicates the operation is not supported. +// +// +stateify savable +type ErrNotSupported struct{} + +func (*ErrNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotSupported) IgnoreStats() bool { + return false +} +func (*ErrNotSupported) String() string { return "operation not supported" } + +// ErrPortInUse indicates the provided port is in use. +// +// +stateify savable +type ErrPortInUse struct{} + +func (*ErrPortInUse) isError() {} + +// IgnoreStats implements Error. +func (*ErrPortInUse) IgnoreStats() bool { + return false +} +func (*ErrPortInUse) String() string { return "port is in use" } + +// ErrQueueSizeNotSupported indicates the endpoint does not allow queue size +// operation. +// +// +stateify savable +type ErrQueueSizeNotSupported struct{} + +func (*ErrQueueSizeNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrQueueSizeNotSupported) IgnoreStats() bool { + return false +} +func (*ErrQueueSizeNotSupported) String() string { return "queue size querying not supported" } + +// ErrTimeout indicates the operation timed out. +// +// +stateify savable +type ErrTimeout struct{} + +func (*ErrTimeout) isError() {} + +// IgnoreStats implements Error. +func (*ErrTimeout) IgnoreStats() bool { + return false +} +func (*ErrTimeout) String() string { return "operation timed out" } + +// ErrUnknownDevice indicates an unknown device identifier was provided. +// +// +stateify savable +type ErrUnknownDevice struct{} + +func (*ErrUnknownDevice) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownDevice) IgnoreStats() bool { + return false +} +func (*ErrUnknownDevice) String() string { return "unknown device" } + +// ErrUnknownNICID indicates an unknown NIC ID was provided. +// +// +stateify savable +type ErrUnknownNICID struct{} + +func (*ErrUnknownNICID) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownNICID) IgnoreStats() bool { + return false +} +func (*ErrUnknownNICID) String() string { return "unknown nic id" } + +// ErrUnknownProtocol indicates an unknown protocol was requested. +// +// +stateify savable +type ErrUnknownProtocol struct{} + +func (*ErrUnknownProtocol) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownProtocol) IgnoreStats() bool { + return false +} +func (*ErrUnknownProtocol) String() string { return "unknown protocol" } + +// ErrUnknownProtocolOption indicates an unknown protocol option was provided. +// +// +stateify savable +type ErrUnknownProtocolOption struct{} + +func (*ErrUnknownProtocolOption) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownProtocolOption) IgnoreStats() bool { + return false +} +func (*ErrUnknownProtocolOption) String() string { return "unknown option for protocol" } + +// ErrWouldBlock indicates the operation would block. +// +// +stateify savable +type ErrWouldBlock struct{} + +func (*ErrWouldBlock) isError() {} + +// IgnoreStats implements Error. +func (*ErrWouldBlock) IgnoreStats() bool { + return true +} +func (*ErrWouldBlock) String() string { return "operation would block" } diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD index 114d43df3..bb9d44aff 100644 --- a/pkg/tcpip/faketime/BUILD +++ b/pkg/tcpip/faketime/BUILD @@ -6,10 +6,7 @@ go_library( name = "faketime", srcs = ["faketime.go"], visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "@com_github_dpjacques_clockwork//:go_default_library", - ], + deps = ["//pkg/tcpip"], ) go_test( diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index f7a4fbde1..fb819d7a8 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -17,10 +17,10 @@ package faketime import ( "container/heap" + "fmt" "sync" "time" - "github.com/dpjacques/clockwork" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -44,38 +44,85 @@ func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer { return nil } +type notificationChannels struct { + mu struct { + sync.Mutex + + ch []<-chan struct{} + } +} + +func (n *notificationChannels) add(ch <-chan struct{}) { + n.mu.Lock() + defer n.mu.Unlock() + n.mu.ch = append(n.mu.ch, ch) +} + +// wait returns once all the notification channels are readable. +// +// Channels that are added while waiting on existing channels will be waited on +// as well. +func (n *notificationChannels) wait() { + for { + n.mu.Lock() + ch := n.mu.ch + n.mu.ch = nil + n.mu.Unlock() + + if len(ch) == 0 { + break + } + + for _, c := range ch { + <-c + } + } +} + // ManualClock implements tcpip.Clock and only advances manually with Advance // method. type ManualClock struct { - clock clockwork.FakeClock + // runningTimers tracks the completion of timer callbacks that began running + // immediately upon their scheduling. It is used to ensure the proper ordering + // of timer callback dispatch. + runningTimers notificationChannels + + mu struct { + sync.RWMutex - // mu protects the fields below. - mu sync.RWMutex + // now is the current (fake) time of the clock. + now time.Time - // times is min-heap of times. A heap is used for quick retrieval of the next - // upcoming time of scheduled work. - times *timeHeap + // times is min-heap of times. + times timeHeap - // waitGroups stores one WaitGroup for all work scheduled to execute at the - // same time via AfterFunc. This allows parallel execution of all functions - // passed to AfterFunc scheduled for the same time. - waitGroups map[time.Time]*sync.WaitGroup + // timers holds the timers scheduled for each time. + timers map[time.Time]map[*manualTimer]struct{} + } } // NewManualClock creates a new ManualClock instance. func NewManualClock() *ManualClock { - return &ManualClock{ - clock: clockwork.NewFakeClock(), - times: &timeHeap{}, - waitGroups: make(map[time.Time]*sync.WaitGroup), - } + c := &ManualClock{} + + c.mu.Lock() + defer c.mu.Unlock() + + // Set the initial time to a non-zero value since the zero value is used to + // detect inactive timers. + c.mu.now = time.Unix(0, 0) + c.mu.timers = make(map[time.Time]map[*manualTimer]struct{}) + + return c } var _ tcpip.Clock = (*ManualClock)(nil) // NowNanoseconds implements tcpip.Clock.NowNanoseconds. func (mc *ManualClock) NowNanoseconds() int64 { - return mc.clock.Now().UnixNano() + mc.mu.RLock() + defer mc.mu.RUnlock() + return mc.mu.now.UnixNano() } // NowMonotonic implements tcpip.Clock.NowMonotonic. @@ -85,128 +132,203 @@ func (mc *ManualClock) NowMonotonic() int64 { // AfterFunc implements tcpip.Clock.AfterFunc. func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { - until := mc.clock.Now().Add(d) - wg := mc.addWait(until) - return &manualTimer{ + mt := &manualTimer{ clock: mc, - until: until, - timer: mc.clock.AfterFunc(d, func() { - defer wg.Done() - f() - }), + f: f, } -} -// addWait adds an additional wait to the WaitGroup for parallel execution of -// all work scheduled for t. Returns a reference to the WaitGroup modified. -func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup { - mc.mu.RLock() - wg, ok := mc.waitGroups[t] - mc.mu.RUnlock() + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() - if ok { - wg.Add(1) - return wg + mc.resetTimerLocked(mt, d) + return mt +} + +// resetTimerLocked schedules a timer to be fired after the given duration. +// +// Precondition: mc.mu and mt.mu must be locked. +func (mc *ManualClock) resetTimerLocked(mt *manualTimer, d time.Duration) { + if !mt.mu.firesAt.IsZero() { + panic("tried to reset an active timer") } - mc.mu.Lock() - heap.Push(mc.times, t) - mc.mu.Unlock() + t := mc.mu.now.Add(d) - wg = &sync.WaitGroup{} - wg.Add(1) + if !mc.mu.now.Before(t) { + // If the timer is scheduled to fire immediately, call its callback + // in a new goroutine immediately. + // + // It needs to be called in its own goroutine to escape its current + // execution context - like an actual timer. + ch := make(chan struct{}) + mc.runningTimers.add(ch) - mc.mu.Lock() - mc.waitGroups[t] = wg - mc.mu.Unlock() + go func() { + defer close(ch) + + mt.f() + }() - return wg + return + } + + mt.mu.firesAt = t + + timers, ok := mc.mu.timers[t] + if !ok { + timers = make(map[*manualTimer]struct{}) + mc.mu.timers[t] = timers + heap.Push(&mc.mu.times, t) + } + + timers[mt] = struct{}{} } -// removeWait removes a wait from the WaitGroup for parallel execution of all -// work scheduled for t. -func (mc *ManualClock) removeWait(t time.Time) { - mc.mu.RLock() - defer mc.mu.RUnlock() +// stopTimerLocked stops a timer from firing. +// +// Precondition: mc.mu and mt.mu must be locked. +func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { + t := mt.mu.firesAt + mt.mu.firesAt = time.Time{} + + if t.IsZero() { + panic("tried to stop an inactive timer") + } - wg := mc.waitGroups[t] - wg.Done() + timers, ok := mc.mu.timers[t] + if !ok { + err := fmt.Sprintf("tried to stop an active timer but the clock does not have anything scheduled for the timer @ t = %s %p\nScheduled timers @:", t.UTC(), mt) + for t := range mc.mu.timers { + err += fmt.Sprintf("%s\n", t.UTC()) + } + panic(err) + } + + if _, ok := timers[mt]; !ok { + panic(fmt.Sprintf("did not have an entry in timers for an active timer @ t = %s", t.UTC())) + } + + delete(timers, mt) + + if len(timers) == 0 { + delete(mc.mu.timers, t) + } } // Advance executes all work that have been scheduled to execute within d from -// the current time. Blocks until all work has completed execution. +// the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { - // Block until all the work is done - until := mc.clock.Now().Add(d) - for { - mc.mu.Lock() - if mc.times.Len() == 0 { - mc.mu.Unlock() - break - } + // We spawn goroutines for timers that were scheduled to fire at the time of + // being reset. Wait for those goroutines to complete before proceeding so + // that timer callbacks are called in the right order. + mc.runningTimers.wait() - t := heap.Pop(mc.times).(time.Time) + mc.mu.Lock() + defer mc.mu.Unlock() + + until := mc.mu.now.Add(d) + for mc.mu.times.Len() > 0 { + t := heap.Pop(&mc.mu.times).(time.Time) if t.After(until) { // No work to do - heap.Push(mc.times, t) - mc.mu.Unlock() + heap.Push(&mc.mu.times, t) break } - mc.mu.Unlock() - diff := t.Sub(mc.clock.Now()) - mc.clock.Advance(diff) + timers := mc.mu.timers[t] + delete(mc.mu.timers, t) + + mc.mu.now = t + + // Mark the timers as inactive since they will be fired. + // + // This needs to be done while holding mc's lock because we remove the entry + // in the map of timers for the current time. If an attempt to stop a + // timer is made after mc's lock was dropped but before the timer is + // marked inactive, we would panic since no entry exists for the time when + // the timer was expected to fire. + for mt := range timers { + mt.mu.Lock() + mt.mu.firesAt = time.Time{} + mt.mu.Unlock() + } - mc.mu.RLock() - wg := mc.waitGroups[t] - mc.mu.RUnlock() + // Release the lock before calling the timer's callback fn since the + // callback fn might try to schedule a timer which requires obtaining + // mc's lock. + mc.mu.Unlock() - wg.Wait() + for mt := range timers { + mt.f() + } + // The timer callbacks may have scheduled a timer to fire immediately. + // We spawn goroutines for these timers and need to wait for them to + // finish before proceeding so that timer callbacks are called in the + // right order. + mc.runningTimers.wait() mc.mu.Lock() - delete(mc.waitGroups, t) - mc.mu.Unlock() } - if now := mc.clock.Now(); until.After(now) { - mc.clock.Advance(until.Sub(now)) + + mc.mu.now = until +} + +func (mc *ManualClock) resetTimer(mt *manualTimer, d time.Duration) { + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() + + if !mt.mu.firesAt.IsZero() { + mc.stopTimerLocked(mt) } + + mc.resetTimerLocked(mt, d) +} + +func (mc *ManualClock) stopTimer(mt *manualTimer) bool { + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() + + if mt.mu.firesAt.IsZero() { + return false + } + + mc.stopTimerLocked(mt) + return true } type manualTimer struct { clock *ManualClock - timer clockwork.Timer + f func() - mu sync.RWMutex - until time.Time + mu struct { + sync.Mutex + + // firesAt is the time when the timer will fire. + // + // Zero only when the timer is not active. + firesAt time.Time + } } var _ tcpip.Timer = (*manualTimer)(nil) // Reset implements tcpip.Timer.Reset. -func (t *manualTimer) Reset(d time.Duration) { - if !t.timer.Reset(d) { - return - } - - t.mu.Lock() - defer t.mu.Unlock() - - t.clock.removeWait(t.until) - t.until = t.clock.clock.Now().Add(d) - t.clock.addWait(t.until) +func (mt *manualTimer) Reset(d time.Duration) { + mt.clock.resetTimer(mt, d) } // Stop implements tcpip.Timer.Stop. -func (t *manualTimer) Stop() bool { - if !t.timer.Stop() { - return false - } - - t.mu.RLock() - defer t.mu.RUnlock() - - t.clock.removeWait(t.until) - return true +func (mt *manualTimer) Stop() bool { + return mt.clock.stopTimer(mt) } type timeHeap []time.Time diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e6103f4bc..48ca60319 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package header import ( "encoding/binary" - "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -481,15 +480,13 @@ const ( IPv4OptionLengthOffset = 1 ) -// Potential errors when parsing generic IP options. -var ( - ErrIPv4OptZeroLength = errors.New("zero length IP option") - ErrIPv4OptDuplicate = errors.New("duplicate IP option") - ErrIPv4OptInvalid = errors.New("invalid IP option") - ErrIPv4OptMalformed = errors.New("malformed IP option") - ErrIPv4OptionTruncated = errors.New("truncated IP option") - ErrIPv4OptionAddress = errors.New("bad IP option address") -) +// IPv4OptParameterProblem indicates that a Parameter Problem message +// should be generated, and gives the offset in the current entity +// that should be used in that packet. +type IPv4OptParameterProblem struct { + Pointer uint8 + NeedICMP bool +} // IPv4Option is an interface representing various option types. type IPv4Option interface { @@ -583,8 +580,9 @@ func (i *IPv4OptionIterator) Finalize() IPv4Options { // It returns // - A slice of bytes holding the next option or nil if there is error. // - A boolean which is true if parsing of all the options is complete. -// - An error which is non-nil if an error condition was encountered. -func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { +// Undefined in the case of error. +// - An error indication which is non-nil if an error condition was found. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, *IPv4OptParameterProblem) { // The opts slice gets shorter as we process the options. When we have no // bytes left we are done. if len(i.options) == 0 { @@ -606,24 +604,22 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { // There are no more single byte options defined. All the rest have a length // field so we need to sanity check it. if len(i.options) == 1 { - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } optLen := i.options[IPv4OptionLengthOffset] - if optLen == 0 { - i.ErrCursor++ - return nil, true, ErrIPv4OptZeroLength - } + if optLen <= IPv4OptionLengthOffset || optLen > uint8(len(i.options)) { + // The actual error is in the length (2nd byte of the option) but we + // return the start of the option for compatibility with Linux. - if optLen == 1 { - i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed - } - - if optLen > uint8(len(i.options)) { - i.ErrCursor++ - return nil, true, ErrIPv4OptionTruncated + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } optionBody := i.options[:optLen] @@ -635,7 +631,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { case IPv4OptionTimestampType: if optLen < IPv4OptionTimestampHdrLength { i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } retval := IPv4OptionTimestamp(optionBody) return &retval, false, nil @@ -643,7 +642,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { case IPv4OptionRecordRouteType: if optLen < IPv4OptionRecordRouteHdrLength { i.ErrCursor++ - return nil, true, ErrIPv4OptMalformed + return nil, false, &IPv4OptParameterProblem{ + Pointer: i.ErrCursor, + NeedICMP: true, + } } retval := IPv4OptionRecordRoute(optionBody) return &retval, false, nil diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 5580d6a78..f2403978c 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -453,9 +453,9 @@ const ( ) // ScopeForIPv6Address returns the scope for an IPv6 address. -func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { +func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) { if len(addr) != IPv6AddressSize { - return GlobalScope, tcpip.ErrBadAddress + return GlobalScope, &tcpip.ErrBadAddress{} } switch { diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index e3fbd64f3..f10f446a6 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -299,7 +299,7 @@ func TestScopeForIPv6Address(t *testing.T) { name string addr tcpip.Address scope header.IPv6AddressScope - err *tcpip.Error + err tcpip.Error }{ { name: "Unique Local", @@ -329,15 +329,15 @@ func TestScopeForIPv6Address(t *testing.T) { name: "IPv4", addr: "\x01\x02\x03\x04", scope: header.GlobalScope, - err: tcpip.ErrBadAddress, + err: &tcpip.ErrBadAddress{}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { got, err := header.ScopeForIPv6Address(test.addr) - if err != test.err { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err) + if diff := cmp.Diff(test.err, err); diff != "" { + t.Errorf("unexpected error from header.IsV6UniqueLocalAddress(%s), (-want, +got):\n%s", test.addr, diff) } if got != test.scope { t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a068d93a4..cd76272de 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -229,7 +229,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { p := PacketInfo{ Pkt: pkt, Proto: protocol, @@ -243,7 +243,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index 2f2d9d4ac..d873766a6 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -61,13 +61,13 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) return e.Endpoint.WritePacket(r, gso, proto, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 10072eac1..ae1394ebf 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -35,7 +35,6 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f86c383d8..0164d851b 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -57,7 +57,7 @@ import ( // linkDispatcher reads packets from the link FD and dispatches them to the // NetworkDispatcher. type linkDispatcher interface { - dispatch() (bool, *tcpip.Error) + dispatch() (bool, tcpip.Error) } // PacketDispatchMode are the various supported methods of receiving and @@ -118,7 +118,7 @@ type endpoint struct { // closed is a function to be called when the FD's peer (if any) closes // its end of the communication pipe. - closed func(*tcpip.Error) + closed func(tcpip.Error) inboundDispatchers []linkDispatcher dispatcher stack.NetworkDispatcher @@ -149,7 +149,7 @@ type Options struct { // ClosedFunc is a function to be called when an endpoint's peer (if // any) closes its end of the communication pipe. - ClosedFunc func(*tcpip.Error) + ClosedFunc func(tcpip.Error) // Address is the link address for this endpoint. Only used if // EthernetHeader is true. @@ -411,7 +411,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if e.hdrSize > 0 { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } @@ -451,7 +451,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } -func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { +func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) { // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { @@ -518,7 +518,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { // Preallocate to avoid repeated reallocation as we append to batch. // batchSz is 47 because when SWGSO is in use then a single 65KB TCP // segment can get split into 46 segments of 1420 bytes and a single 216 @@ -562,13 +562,13 @@ func viewsEqual(vs1, vs2 []buffer.View) bool { } // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. -func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { +func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) } // dispatchLoop reads packets from the file descriptor in a loop and dispatches // them to the network stack. -func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error { +func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { for { cont, err := inboundDispatcher.dispatch() if err != nil || !cont { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 90da22d34..e82371798 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -96,7 +95,7 @@ func newContext(t *testing.T, opt *Options) *context { } done := make(chan struct{}, 2) - opt.ClosedFunc = func(*tcpip.Error) { + opt.ClosedFunc = func(tcpip.Error) { done <- struct{}{} } @@ -465,67 +464,85 @@ var capLengthTestCases = []struct { config: []int{1, 2, 3}, n: 3, wantUsed: 2, - wantLengths: []int{1, 2, 3}, + wantLengths: []int{1, 2}, }, } -func TestReadVDispatcherCapLength(t *testing.T) { +func TestIovecBuffer(t *testing.T) { for _, c := range capLengthTestCases { - // fd does not matter for this test. - d := readVDispatcher{fd: -1, e: &endpoint{}} - d.views = make([]buffer.View, len(c.config)) - d.iovecs = make([]syscall.Iovec, len(c.config)) - d.allocateViews(c.config) - - used := d.capViews(c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views)) - for i, v := range d.views { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } - } -} + t.Run(c.comment, func(t *testing.T) { + b := newIovecBuffer(c.config, false /* skipsVnetHdr */) -func TestRecvMMsgDispatcherCapLength(t *testing.T) { - for _, c := range capLengthTestCases { - d := recvMMsgDispatcher{ - fd: -1, // fd does not matter for this test. - e: &endpoint{}, - views: make([][]buffer.View, 1), - iovecs: make([][]syscall.Iovec, 1), - msgHdrs: make([]rawfile.MMsgHdr, 1), - } + // Test initial allocation. + iovecs := b.nextIovecs() + if got, want := len(iovecs), len(c.config); got != want { + t.Fatalf("len(iovecs) = %d, want %d", got, want) + } - for i := range d.views { - d.views[i] = make([]buffer.View, len(c.config)) - } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, len(c.config)) - } - for k, msgHdr := range d.msgHdrs { - msgHdr.Msg.Iov = &d.iovecs[k][0] - msgHdr.Msg.Iovlen = uint64(len(c.config)) - } + // Make a copy as iovecs points to internal slice. We will need this state + // later. + oldIovecs := append([]syscall.Iovec(nil), iovecs...) - d.allocateViews(c.config) + // Test the views that get pulled. + vv := b.pullViews(c.n) + var lengths []int + for _, v := range vv.Views() { + lengths = append(lengths, len(v)) + } + if !reflect.DeepEqual(lengths, c.wantLengths) { + t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths) + } - used := d.capViews(0, c.n, c.config) - if used != c.wantUsed { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) - } - lengths := make([]int, len(d.views[0])) - for i, v := range d.views[0] { - lengths[i] = len(v) - } - if !reflect.DeepEqual(lengths, c.wantLengths) { - t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) - } + // Test that new views get reallocated. + for i, newIov := range b.nextIovecs() { + if i < c.wantUsed { + if newIov.Base == oldIovecs[i].Base { + t.Errorf("b.views[%d] should have been reallocated", i) + } + } else { + if newIov.Base != oldIovecs[i].Base { + t.Errorf("b.views[%d] should not have been reallocated", i) + } + } + } + }) + } +} +func TestIovecBufferSkipVnetHdr(t *testing.T) { + for _, test := range []struct { + desc string + readN int + wantLen int + }{ + { + desc: "nothing read", + readN: 0, + wantLen: 0, + }, + { + desc: "smaller than vnet header", + readN: virtioNetHdrSize - 1, + wantLen: 0, + }, + { + desc: "header skipped", + readN: virtioNetHdrSize + 100, + wantLen: 100, + }, + } { + t.Run(test.desc, func(t *testing.T) { + b := newIovecBuffer([]int{10, 20, 50, 50}, true) + // Pretend a read happend. + b.nextIovecs() + vv := b.pullViews(test.readN) + if got, want := vv.Size(), test.wantLen; got != want { + t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want) + } + if got, want := len(vv.ToOwnedView()), test.wantLen; got != want { + t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want) + } + }) } } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index c475dda20..a2b63fe6b 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -129,7 +129,7 @@ type packetMMapDispatcher struct { ringOffset int } -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) for hdr.tpStatus()&tpStatusUser == 0 { event := rawfile.PollEvent{ @@ -163,7 +163,7 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { // dispatch reads packets from an mmaped ring buffer and dispatches them to the // network stack. -func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { +func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { pkt, err := d.readMMappedPacket() if err != nil { return false, err diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 8c3ca86d6..ecae1ad2d 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -29,92 +29,124 @@ import ( // BufConfig defines the shape of the vectorised view used to read packets from the NIC. var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} -// readVDispatcher uses readv() system call to read inbound packets and -// dispatches them. -type readVDispatcher struct { - // fd is the file descriptor used to send and receive packets. - fd int - - // e is the endpoint this dispatcher is attached to. - e *endpoint - +type iovecBuffer struct { // views are the actual buffers that hold the packet contents. views []buffer.View // iovecs are initialized with base pointers/len of the corresponding - // entries in the views defined above, except when GSO is enabled then - // the first iovec points to a buffer for the vnet header which is - // stripped before the views are passed up the stack for further + // entries in the views defined above, except when GSO is enabled + // (skipsVnetHdr) then the first iovec points to a buffer for the vnet header + // which is stripped before the views are passed up the stack for further // processing. iovecs []syscall.Iovec + + // sizes is an array of buffer sizes for the underlying views. sizes is + // immutable. + sizes []int + + // skipsVnetHdr is true if virtioNetHdr is to skipped. + skipsVnetHdr bool } -func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { - d := &readVDispatcher{fd: fd, e: e} - d.views = make([]buffer.View, len(BufConfig)) - iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - iovLen++ +func newIovecBuffer(sizes []int, skipsVnetHdr bool) *iovecBuffer { + b := &iovecBuffer{ + views: make([]buffer.View, len(sizes)), + sizes: sizes, + skipsVnetHdr: skipsVnetHdr, } - d.iovecs = make([]syscall.Iovec, iovLen) - return d, nil + niov := len(b.views) + if b.skipsVnetHdr { + niov++ + } + b.iovecs = make([]syscall.Iovec, niov) + return b } -func (d *readVDispatcher) allocateViews(bufConfig []int) { - var vnetHdr [virtioNetHdrSize]byte +func (b *iovecBuffer) nextIovecs() []syscall.Iovec { vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if b.skipsVnetHdr { + var vnetHdr [virtioNetHdrSize]byte // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. - d.iovecs[0] = syscall.Iovec{ + b.iovecs[0] = syscall.Iovec{ Base: &vnetHdr[0], Len: uint64(virtioNetHdrSize), } vnetHdrOff++ } - for i := 0; i < len(bufConfig); i++ { - if d.views[i] != nil { + for i := range b.views { + if b.views[i] != nil { break } - b := buffer.NewView(bufConfig[i]) - d.views[i] = b - d.iovecs[i+vnetHdrOff] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), + v := buffer.NewView(b.sizes[i]) + b.views[i] = v + b.iovecs[i+vnetHdrOff] = syscall.Iovec{ + Base: &v[0], + Len: uint64(len(v)), } } + return b.iovecs } -func (d *readVDispatcher) capViews(n int, buffers []int) int { +func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView { + var views []buffer.View c := 0 - for i, s := range buffers { - c += s + if b.skipsVnetHdr { + c += virtioNetHdrSize if c >= n { - d.views[i].CapLength(s - (c - n)) - return i + 1 + // Nothing in the packet. + return buffer.NewVectorisedView(0, nil) + } + } + for i, v := range b.views { + c += len(v) + if c >= n { + b.views[i].CapLength(len(v) - (c - n)) + views = append([]buffer.View(nil), b.views[:i+1]...) + break } } - return len(buffers) + // Remove the first len(views) used views from the state. + for i := range views { + b.views[i] = nil + } + if b.skipsVnetHdr { + // Exclude the size of the vnet header. + n -= virtioNetHdrSize + } + return buffer.NewVectorisedView(n, views) } -// dispatch reads one packet from the file descriptor and dispatches it. -func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { - d.allocateViews(BufConfig) +// readVDispatcher uses readv() system call to read inbound packets and +// dispatches them. +type readVDispatcher struct { + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // buf is the iovec buffer that contains the packet contents. + buf *iovecBuffer +} + +func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + d := &readVDispatcher{fd: fd, e: e} + skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) + return d, nil +} - n, err := rawfile.BlockingReadv(d.fd, d.iovecs) +// dispatch reads one packet from the file descriptor and dispatches it. +func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { + n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs()) if n == 0 || err != nil { return false, err } - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // Skip virtioNetHdr which is added before each packet, it - // isn't used and it isn't in a view. - n -= virtioNetHdrSize - } - used := d.capViews(n, BufConfig) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + Data: d.buf.pullViews(n), }) var ( @@ -133,7 +165,12 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(d.views[0]) { + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data.PullUp(1) + if !ok { + return true, nil + } + switch header.IPVersion(h) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: @@ -145,11 +182,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) - // Prepare e.views for another packet: release used views. - for i := 0; i < used; i++ { - d.views[i] = nil - } - return true, nil } @@ -162,15 +194,8 @@ type recvMMsgDispatcher struct { // e is the endpoint this dispatcher is attached to. e *endpoint - // views is an array of array of buffers that contain packet contents. - views [][]buffer.View - - // iovecs is an array of array of iovec records where each iovec base - // pointer and length are initialzed to the corresponding view above, - // except when GSO is enabled then the first iovec in each array of - // iovecs points to a buffer for the vnet header which is stripped - // before the views are passed up the stack for further processing. - iovecs [][]syscall.Iovec + // bufs is an array of iovec buffers that contain packet contents. + bufs []*iovecBuffer // msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to // reference an array of iovecs in the iovecs field defined above. This @@ -187,74 +212,32 @@ const ( func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &recvMMsgDispatcher{ - fd: fd, - e: e, - } - d.views = make([][]buffer.View, MaxMsgsPerRecv) - for i := range d.views { - d.views[i] = make([]buffer.View, len(BufConfig)) - } - d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv) - iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // virtioNetHdr is prepended before each packet. - iovLen++ + fd: fd, + e: e, + bufs: make([]*iovecBuffer, MaxMsgsPerRecv), + msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } - for i := range d.iovecs { - d.iovecs[i] = make([]syscall.Iovec, iovLen) - } - d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv) - for i := range d.msgHdrs { - d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0] - d.msgHdrs[i].Msg.Iovlen = uint64(iovLen) + skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + for i := range d.bufs { + d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) } return d, nil } -func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int { - c := 0 - for i, s := range buffers { - c += s - if c >= n { - d.views[k][i].CapLength(s - (c - n)) - return i + 1 - } - } - return len(buffers) -} - -func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) { - for k := 0; k < len(d.views); k++ { - var vnetHdr [virtioNetHdrSize]byte - vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - // The kernel adds virtioNetHdr before each packet, but - // we don't use it, so so we allocate a buffer for it, - // add it in iovecs but don't add it in a view. - d.iovecs[k][0] = syscall.Iovec{ - Base: &vnetHdr[0], - Len: uint64(virtioNetHdrSize), - } - vnetHdrOff++ - } - for i := 0; i < len(bufConfig); i++ { - if d.views[k][i] != nil { - break - } - b := buffer.NewView(bufConfig[i]) - d.views[k][i] = b - d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{ - Base: &b[0], - Len: uint64(len(b)), - } - } - } -} - // recvMMsgDispatch reads more than one packet at a time from the file // descriptor and dispatches it. -func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { - d.allocateViews(BufConfig) +func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { + // Fill message headers. + for k := range d.msgHdrs { + if d.msgHdrs[k].Msg.Iovlen > 0 { + break + } + iovecs := d.bufs[k].nextIovecs() + iovLen := len(iovecs) + d.msgHdrs[k].Len = 0 + d.msgHdrs[k].Msg.Iov = &iovecs[0] + d.msgHdrs[k].Msg.Iovlen = uint64(iovLen) + } nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs) if err != nil { @@ -263,15 +246,14 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { // Process each of received packets. for k := 0; k < nMsgs; k++ { n := int(d.msgHdrs[k].Len) - if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - n -= virtioNetHdrSize - } - used := d.capViews(k, int(n), BufConfig) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + Data: d.bufs[k].pullViews(n), }) + // Mark that this iovec has been processed. + d.msgHdrs[k].Msg.Iovlen = 0 + var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress @@ -288,26 +270,24 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } else { // We don't get any indication of what the packet is, so try to guess // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(d.views[k][0]) { + // IP version information is at the first octet, so pulling up 1 byte. + h, ok := pkt.Data.PullUp(1) + if !ok { + // Skip this packet. + continue + } + switch header.IPVersion(h) { case header.IPv4Version: p = header.IPv4ProtocolNumber case header.IPv6Version: p = header.IPv6ProtocolNumber default: - return true, nil + // Skip this packet. + continue } } d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) - - // Prepare e.views for another packet: release used views. - for i := 0; i < used; i++ { - d.views[k][i] = nil - } - } - - for k := 0; k < nMsgs; k++ { - d.msgHdrs[k].Len = 0 } return true, nil diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ac6a6be87..691467870 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // Construct data as the unparsed portion for the loopback packet. data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 316f508e6..668f72eee 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,10 +87,10 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } return endpoint.WritePackets(r, gso, pkts, protocol) } @@ -98,19 +98,19 @@ func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkt // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } // InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error { endpoint, ok := m.routes[dest] if !ok { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } return endpoint.InjectOutbound(dest, packet) } diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 814a54f23..97ad9fdd5 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -113,12 +113,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return e.child.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return e.child.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index c95cdd681..6cbe18a56 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -35,13 +35,13 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index d6e83a414..bbe84f220 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -45,12 +45,7 @@ type Endpoint struct { linkAddr tcpip.LinkAddress } -// WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - if !e.linked.IsAttached() { - return nil - } - +func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) { // Note that the local address from the perspective of this endpoint is the // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this @@ -70,16 +65,33 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw // // TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send // and receive queues. - go e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) + go func() { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + } + }() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + if e.linked.IsAttached() { + var pkts stack.PacketBufferList + pkts.PushBack(pkt) + e.deliverPackets(r, proto, pkts) + } return nil } // WritePackets implements stack.LinkEndpoint. -func (*Endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - panic("not implemented") +func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + if e.linked.IsAttached() { + e.deliverPackets(r, proto, pkts) + } + + return pkts.Len(), nil } // Attach implements stack.LinkEndpoint. diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 87035b034..128ef6e87 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -150,7 +150,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. pkt.EgressRoute = r @@ -158,26 +158,29 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] if !d.q.enqueue(pkt) { - return tcpip.ErrNoBufferSpace + return &tcpip.ErrNoBufferSpace{} } d.newPacketWaker.Assert() return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +// +// Being a batch API, each packet in pkts should have the following +// fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { enqueued := 0 for pkt := pkts.Front(); pkt != nil; { - pkt.EgressRoute = r - pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] nxt := pkt.Next() if !d.q.enqueue(pkt) { if enqueued > 0 { d.newPacketWaker.Assert() } - return enqueued, tcpip.ErrNoBufferSpace + return enqueued, &tcpip.ErrNoBufferSpace{} } pkt = nxt enqueued++ diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 6c410c5a6..e1047da50 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -27,5 +27,6 @@ go_test( library = "rawfile", deps = [ "//pkg/tcpip", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index 604868fd8..406b97709 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -17,7 +17,6 @@ package rawfile import ( - "fmt" "syscall" "gvisor.dev/gvisor/pkg/tcpip" @@ -25,48 +24,54 @@ import ( const maxErrno = 134 -var translations [maxErrno]*tcpip.Error - // TranslateErrno translate an errno from the syscall package into a -// *tcpip.Error. +// tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). -func TranslateErrno(e syscall.Errno) *tcpip.Error { - if e > 0 && e < syscall.Errno(len(translations)) { - if err := translations[e]; err != nil { - return err - } - } - return tcpip.ErrInvalidEndpointState -} - -func addTranslation(host syscall.Errno, trans *tcpip.Error) { - if translations[host] != nil { - panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host)) +// *tcpip.ErrInvalidEndpointState (EINVAL). +func TranslateErrno(e syscall.Errno) tcpip.Error { + switch e { + case syscall.EEXIST: + return &tcpip.ErrDuplicateAddress{} + case syscall.ENETUNREACH: + return &tcpip.ErrNoRoute{} + case syscall.EINVAL: + return &tcpip.ErrInvalidEndpointState{} + case syscall.EALREADY: + return &tcpip.ErrAlreadyConnecting{} + case syscall.EISCONN: + return &tcpip.ErrAlreadyConnected{} + case syscall.EADDRINUSE: + return &tcpip.ErrPortInUse{} + case syscall.EADDRNOTAVAIL: + return &tcpip.ErrBadLocalAddress{} + case syscall.EPIPE: + return &tcpip.ErrClosedForSend{} + case syscall.EWOULDBLOCK: + return &tcpip.ErrWouldBlock{} + case syscall.ECONNREFUSED: + return &tcpip.ErrConnectionRefused{} + case syscall.ETIMEDOUT: + return &tcpip.ErrTimeout{} + case syscall.EINPROGRESS: + return &tcpip.ErrConnectStarted{} + case syscall.EDESTADDRREQ: + return &tcpip.ErrDestinationRequired{} + case syscall.ENOTSUP: + return &tcpip.ErrNotSupported{} + case syscall.ENOTTY: + return &tcpip.ErrQueueSizeNotSupported{} + case syscall.ENOTCONN: + return &tcpip.ErrNotConnected{} + case syscall.ECONNRESET: + return &tcpip.ErrConnectionReset{} + case syscall.ECONNABORTED: + return &tcpip.ErrConnectionAborted{} + case syscall.EMSGSIZE: + return &tcpip.ErrMessageTooLong{} + case syscall.ENOBUFS: + return &tcpip.ErrNoBufferSpace{} + default: + return &tcpip.ErrInvalidEndpointState{} } - translations[host] = trans -} - -func init() { - addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress) - addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute) - addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState) - addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting) - addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected) - addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse) - addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress) - addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend) - addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock) - addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused) - addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout) - addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted) - addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired) - addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported) - addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported) - addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected) - addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset) - addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted) - addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong) - addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace) } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go index e4cdc66bd..61aea1744 100644 --- a/pkg/tcpip/link/rawfile/errors_test.go +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -20,34 +20,35 @@ import ( "syscall" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) func TestTranslateErrno(t *testing.T) { for _, test := range []struct { errno syscall.Errno - translated *tcpip.Error + translated tcpip.Error }{ { errno: syscall.Errno(0), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.Errno(maxErrno), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.Errno(514), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.EEXIST, - translated: tcpip.ErrDuplicateAddress, + translated: &tcpip.ErrDuplicateAddress{}, }, } { got := TranslateErrno(test.errno) - if got != test.translated { - t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + if diff := cmp.Diff(test.translated, got); diff != "" { + t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff) } } } diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index f4c32c2da..06f3ee21e 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -52,7 +52,7 @@ func GetMTU(name string) (uint32, error) { // NonBlockingWrite writes the given buffer to a file descriptor. It fails if // partial data is written. -func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { +func NonBlockingWrite(fd int, buf []byte) tcpip.Error { var ptr unsafe.Pointer if len(buf) > 0 { ptr = unsafe.Pointer(&buf[0]) @@ -68,7 +68,7 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { // NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. // It fails if partial data is written. -func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) tcpip.Error { iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { @@ -78,7 +78,7 @@ func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { } // NonBlockingSendMMsg sends multiple messages on a socket. -func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { +func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) if e != 0 { return 0, TranslateErrno(e) @@ -97,7 +97,7 @@ type PollEvent struct { // BlockingRead reads from a file descriptor that is set up as non-blocking. If // no data is available, it will block in a poll() syscall until the file // descriptor becomes readable. -func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { +func BlockingRead(fd int, b []byte) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) if e == 0 { @@ -119,7 +119,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { // BlockingReadv reads from a file descriptor that is set up as non-blocking and // stores the data in a list of iovecs buffers. If no data is available, it will // block in a poll() syscall until the file descriptor becomes readable. -func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { +func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) if e == 0 { @@ -149,7 +149,7 @@ type MMsgHdr struct { // and stores the received messages in a slice of MMsgHdr structures. If no data // is available, it will block in a poll() syscall until the file descriptor // becomes readable. -func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { +func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) if e == 0 { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6c937c858..2599bc406 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) views := pkt.Views() @@ -213,14 +213,14 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N e.mu.Unlock() if !ok { - return tcpip.ErrWouldBlock + return &tcpip.ErrWouldBlock{} } return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 23242b9e0..d480ad656 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -425,8 +425,9 @@ func TestFillTxQueue(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -493,8 +494,9 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -538,8 +540,8 @@ func TestFillTxMemory(t *testing.T) { Data: buf.ToVectorisedView(), }) err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) - if want := tcpip.ErrWouldBlock; err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -579,8 +581,9 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buffer.NewView(bufferSize).ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 5859851d8..bd2b8d4bf 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -187,7 +187,7 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.dumpPacket(directionSend, gso, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } @@ -195,7 +195,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.dumpPacket(directionSend, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index bfac358f4..3829ca9c9 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -149,10 +149,10 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ Name: endpoint.name, }) - switch err { + switch err.(type) { case nil: return endpoint, nil - case tcpip.ErrDuplicateNICID: + case *tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 30f1ad540..20259b285 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -108,7 +108,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if !e.writeGate.Enter() { return nil } @@ -121,7 +121,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { if !e.writeGate.Enter() { return pkts.Len(), nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index b139de7dd..e368a9eaa 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { e.writeCount += pkts.Len() return pkts.Len(), nil } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9ebf31b78..0caa65251 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -25,5 +25,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 8a6bcfc2c..c7ab876bf 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -4,9 +4,13 @@ package(licenses = ["notice"]) go_library( name = "arp", - srcs = ["arp.go"], + srcs = [ + "arp.go", + "stats.go", + ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -33,3 +37,15 @@ go_test( "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) + +go_test( + name = "stats_test", + size = "small", + srcs = ["stats_test.go"], + library = ":arp", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 3259d052f..7838cc753 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -19,8 +19,10 @@ package arp import ( "fmt" + "reflect" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,11 +52,12 @@ type endpoint struct { nic stack.NetworkInterface linkAddrCache stack.LinkAddressCache nud stack.NUDHandler + stats sharedStats } -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } e.setEnabled(true) @@ -98,10 +101,12 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (*endpoint) Close() {} +func (e *endpoint) Close() { + e.protocol.forgetEndpoint(e.nic.ID()) +} -func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -110,51 +115,45 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported +func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { + return 0, &tcpip.ErrNotSupported{} } -func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats().ARP - stats.PacketsReceived.Increment() + stats := e.stats.arp + stats.packetsReceived.Increment() if !e.isEnabled() { - stats.DisabledPacketsReceived.Increment() + stats.disabledPacketsReceived.Increment() return } h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { - stats.MalformedPacketsReceived.Increment() + stats.malformedPacketsReceived.Increment() return } switch h.Op() { case header.ARPRequest: - stats.RequestsReceived.Increment() + stats.requestsReceived.Increment() localAddr := tcpip.Address(h.ProtocolAddressTarget()) + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + stats.requestsReceivedUnknownTargetAddress.Increment() + return // we have no useful answer, ignore the request + } + + remoteAddr := tcpip.Address(h.ProtocolAddressSender()) + remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + if e.nud == nil { - if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { - stats.RequestsReceivedUnknownTargetAddress.Increment() - return // we have no useful answer, ignore the request - } - - addr := tcpip.Address(h.ProtocolAddressSender()) - linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) } else { - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { - stats.RequestsReceivedUnknownTargetAddress.Increment() - return // we have no useful answer, ignore the request - } - - remoteAddr := tcpip.Address(h.ProtocolAddressSender()) - remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } @@ -186,18 +185,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Send the packet to the (new) target hardware address on the same // hardware on which the request was received. if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil { - stats.OutgoingRepliesDropped.Increment() + stats.outgoingRepliesDropped.Increment() } else { - stats.OutgoingRepliesSent.Increment() + stats.outgoingRepliesSent.Increment() } case header.ARPReply: - stats.RepliesReceived.Increment() + stats.repliesReceived.Increment() addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + e.linkAddrCache.AddLinkAddress(addr, linkAddr) return } @@ -216,9 +215,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + +var _ stack.NetworkProtocol = (*protocol)(nil) +var _ stack.LinkAddressResolver = (*protocol)(nil) + // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack + + mu struct { + sync.RWMutex + + // eps is keyed by NICID to allow protocol methods to retrieve the correct + // endpoint depending on the NIC. + eps map[tcpip.NICID]*endpoint + } } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -236,39 +251,62 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L linkAddrCache: linkAddrCache, nud: nud, } + + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + + stackStats := p.stack.Stats() + e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP) + + p.mu.Lock() + p.mu.eps[nic.ID()] = e + p.mu.Unlock() + return e } +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, nicID) +} + // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { - stats := p.stack.Stats().ARP +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { + nicID := nic.ID() + + p.mu.Lock() + netEP, ok := p.mu.eps[nicID] + p.mu.Unlock() + if !ok { + return &tcpip.ErrNotConnected{} + } + + stats := netEP.stats.arp if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } - nicID := nic.ID() if len(localAddr) == 0 { - addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) - if err != nil { - stats.OutgoingRequestInterfaceHasNoLocalAddressErrors.Increment() - return err + addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if !ok { + return &tcpip.ErrUnknownNICID{} } if len(addr.Address) == 0 { - stats.OutgoingRequestNetworkUnreachableErrors.Increment() - return tcpip.ErrNetworkUnreachable + stats.outgoingRequestInterfaceHasNoLocalAddressErrors.Increment() + return &tcpip.ErrNetworkUnreachable{} } localAddr = addr.Address } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { - stats.OutgoingRequestBadLocalAddressErrors.Increment() - return tcpip.ErrBadLocalAddress + stats.outgoingRequestBadLocalAddressErrors.Increment() + return &tcpip.ErrBadLocalAddress{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -288,10 +326,10 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { - stats.OutgoingRequestsDropped.Increment() + stats.outgoingRequestsDropped.Increment() return err } - stats.OutgoingRequestsSent.Increment() + stats.outgoingRequestsSent.Increment() return nil } @@ -307,13 +345,13 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } // SetOption implements stack.NetworkProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.NetworkProtocol.Option. -func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. @@ -329,5 +367,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu // NewProtocol returns an ARP network protocol. func NewProtocol(s *stack.Stack) stack.NetworkProtocol { - return &protocol{stack: s} + return &protocol{ + stack: s, + mu: struct { + sync.RWMutex + eps map[tcpip.NICID]*endpoint + }{eps: make(map[tcpip.NICID]*endpoint)}, + } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 0536e1698..b0f07aa44 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -125,8 +125,8 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { select { case got := <-d.C: - if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-got +want):\n%s", diff) + if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) } case <-ctx.Done(): return fmt.Errorf("%s for %s", ctx.Err(), want) @@ -537,7 +537,7 @@ type testInterface struct { nicID tcpip.NICID - writeErr *tcpip.Error + writeErr tcpip.Error } func (t *testInterface) ID() tcpip.NICID { @@ -560,15 +560,15 @@ func (*testInterface) Promiscuous() bool { return false } -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) } -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) } -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if t.writeErr != nil { return t.writeErr } @@ -585,104 +585,122 @@ func TestLinkAddressRequest(t *testing.T) { testAddr := tcpip.Address([]byte{1, 2, 3, 4}) tests := []struct { - name string - nicAddr tcpip.Address - localAddr tcpip.Address - remoteLinkAddr tcpip.LinkAddress - - linkErr *tcpip.Error - expectedErr *tcpip.Error - expectedLocalAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress - expectedRequestsSent uint64 - expectedRequestBadLocalAddressErrors uint64 - expectedRequestNetworkUnreachableErrors uint64 - expectedRequestDroppedErrors uint64 + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + linkErr tcpip.Error + expectedErr tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress + expectedRequestsSent uint64 + expectedRequestBadLocalAddressErrors uint64 + expectedRequestInterfaceHasNoLocalAddressErrors uint64 + expectedRequestDroppedErrors uint64 }{ { - name: "Unicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast with unspecified source", + nicAddr: stackAddr, + localAddr: "", + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, - expectedRequestsSent: 1, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast with unspecified source", + nicAddr: stackAddr, + localAddr: "", + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrBadLocalAddress, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestNetworkUnreachableErrors: 0, + name: "Unicast with unassigned address", + nicAddr: stackAddr, + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: &tcpip.ErrBadLocalAddress{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: "", - expectedErr: tcpip.ErrBadLocalAddress, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 1, - expectedRequestNetworkUnreachableErrors: 0, + name: "Multicast with unassigned address", + nicAddr: stackAddr, + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: &tcpip.ErrBadLocalAddress{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 0, }, { - name: "Unicast with no local address available", - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrNetworkUnreachable, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 1, + name: "Unicast with no local address available", + nicAddr: "", + localAddr: "", + remoteLinkAddr: remoteLinkAddr, + expectedErr: &tcpip.ErrNetworkUnreachable{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 1, + expectedRequestDroppedErrors: 0, }, { - name: "Multicast with no local address available", - remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, - expectedRequestsSent: 0, - expectedRequestBadLocalAddressErrors: 0, - expectedRequestNetworkUnreachableErrors: 1, + name: "Multicast with no local address available", + nicAddr: "", + localAddr: "", + remoteLinkAddr: "", + expectedErr: &tcpip.ErrNetworkUnreachable{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 1, + expectedRequestDroppedErrors: 0, }, { - name: "Link error", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - linkErr: tcpip.ErrInvalidEndpointState, - expectedErr: tcpip.ErrInvalidEndpointState, - expectedRequestDroppedErrors: 1, + name: "Link error", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + linkErr: &tcpip.ErrInvalidEndpointState{}, + expectedErr: &tcpip.ErrInvalidEndpointState{}, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestInterfaceHasNoLocalAddressErrors: 0, + expectedRequestDroppedErrors: 1, }, } @@ -714,19 +732,20 @@ func TestLinkAddressRequest(t *testing.T) { // link endpoint even though the stack uses the real NIC to validate the // local address. iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr} - if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) } if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent) } + if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors) + } if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors { t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors) } - if got := s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value(); got != test.expectedRequestNetworkUnreachableErrors { - t.Errorf("got s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value() = %d, want = %d", got, test.expectedRequestNetworkUnreachableErrors) - } if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors { t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors) } @@ -774,11 +793,8 @@ func TestLinkAddressRequestWithoutNIC(t *testing.T) { t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") } - if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrUnknownNICID { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrUnknownNICID) - } - - if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != 1 { - t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = 1", got) + err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}) + if _, ok := err.(*tcpip.ErrNotConnected); !ok { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{}) } } diff --git a/pkg/tcpip/network/arp/stats.go b/pkg/tcpip/network/arp/stats.go new file mode 100644 index 000000000..6d7194c6c --- /dev/null +++ b/pkg/tcpip/network/arp/stats.go @@ -0,0 +1,70 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arp + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to ARP. +type Stats struct { + // ARP holds ARP statistics. + ARP tcpip.ARPStats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*Stats) IsNetworkEndpointStats() {} + +type sharedStats struct { + localStats Stats + arp multiCounterARPStats +} + +// LINT.IfChange(multiCounterARPStats) + +type multiCounterARPStats struct { + packetsReceived tcpip.MultiCounterStat + disabledPacketsReceived tcpip.MultiCounterStat + malformedPacketsReceived tcpip.MultiCounterStat + requestsReceived tcpip.MultiCounterStat + requestsReceivedUnknownTargetAddress tcpip.MultiCounterStat + outgoingRequestInterfaceHasNoLocalAddressErrors tcpip.MultiCounterStat + outgoingRequestBadLocalAddressErrors tcpip.MultiCounterStat + outgoingRequestsDropped tcpip.MultiCounterStat + outgoingRequestsSent tcpip.MultiCounterStat + repliesReceived tcpip.MultiCounterStat + outgoingRepliesDropped tcpip.MultiCounterStat + outgoingRepliesSent tcpip.MultiCounterStat +} + +func (m *multiCounterARPStats) init(a, b *tcpip.ARPStats) { + m.packetsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.disabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) + m.malformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived) + m.requestsReceived.Init(a.RequestsReceived, b.RequestsReceived) + m.requestsReceivedUnknownTargetAddress.Init(a.RequestsReceivedUnknownTargetAddress, b.RequestsReceivedUnknownTargetAddress) + m.outgoingRequestInterfaceHasNoLocalAddressErrors.Init(a.OutgoingRequestInterfaceHasNoLocalAddressErrors, b.OutgoingRequestInterfaceHasNoLocalAddressErrors) + m.outgoingRequestBadLocalAddressErrors.Init(a.OutgoingRequestBadLocalAddressErrors, b.OutgoingRequestBadLocalAddressErrors) + m.outgoingRequestsDropped.Init(a.OutgoingRequestsDropped, b.OutgoingRequestsDropped) + m.outgoingRequestsSent.Init(a.OutgoingRequestsSent, b.OutgoingRequestsSent) + m.repliesReceived.Init(a.RepliesReceived, b.RepliesReceived) + m.outgoingRepliesDropped.Init(a.OutgoingRepliesDropped, b.OutgoingRepliesDropped) + m.outgoingRepliesSent.Init(a.OutgoingRepliesSent, b.OutgoingRepliesSent) +} + +// LINT.ThenChange(../../tcpip.go:ARPStats) diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go new file mode 100644 index 000000000..036fdf739 --- /dev/null +++ b/pkg/tcpip/network/arp/stats_test.go @@ -0,0 +1,93 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arp + +import ( + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.NetworkInterface + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func knownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + nic := testInterface{nicID: 1} + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) + } + + ep.Close() + + proto.mu.Lock() + _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEndpointAfterClose { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } +} + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // expected to be bound by a MultiCounterStat. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD index ca1247c1e..411bca25d 100644 --- a/pkg/tcpip/network/ip/BUILD +++ b/pkg/tcpip/network/ip/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "ip", - srcs = ["generic_multicast_protocol.go"], + srcs = [ + "generic_multicast_protocol.go", + "stats.go", + ], visibility = ["//visibility:public"], deps = [ "//pkg/sync", diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index f2f0e069c..a81f5c8c3 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -174,10 +174,10 @@ type MulticastGroupProtocol interface { // // Returns false if the caller should queue the report to be sent later. Note, // returning false does not mean that the receiver hit an error. - SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) + SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error) // SendLeave sends a multicast leave for the specified group address. - SendLeave(groupAddress tcpip.Address) *tcpip.Error + SendLeave(groupAddress tcpip.Address) tcpip.Error } // GenericMulticastProtocolState is the per interface generic multicast protocol diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 85593f211..d5d5a449e 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -141,7 +141,7 @@ func (m *mockMulticastGroupProtocol) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { if m.mu.TryLock() { m.mu.Unlock() m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) @@ -158,7 +158,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { if m.mu.TryLock() { m.mu.Unlock() m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) diff --git a/pkg/tcpip/network/ip/stats.go b/pkg/tcpip/network/ip/stats.go new file mode 100644 index 000000000..898f8b356 --- /dev/null +++ b/pkg/tcpip/network/ip/stats.go @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import "gvisor.dev/gvisor/pkg/tcpip" + +// LINT.IfChange(MultiCounterIPStats) + +// MultiCounterIPStats holds IP statistics, each counter may have several +// versions. +type MultiCounterIPStats struct { + // PacketsReceived is the total number of IP packets received from the link + // layer. + PacketsReceived tcpip.MultiCounterStat + + // DisabledPacketsReceived is the total number of IP packets received from the + // link layer when the IP layer is disabled. + DisabledPacketsReceived tcpip.MultiCounterStat + + // InvalidDestinationAddressesReceived is the total number of IP packets + // received with an unknown or invalid destination address. + InvalidDestinationAddressesReceived tcpip.MultiCounterStat + + // InvalidSourceAddressesReceived is the total number of IP packets received + // with a source address that should never have been received on the wire. + InvalidSourceAddressesReceived tcpip.MultiCounterStat + + // PacketsDelivered is the total number of incoming IP packets that are + // successfully delivered to the transport layer. + PacketsDelivered tcpip.MultiCounterStat + + // PacketsSent is the total number of IP packets sent via WritePacket. + PacketsSent tcpip.MultiCounterStat + + // OutgoingPacketErrors is the total number of IP packets which failed to + // write to a link-layer endpoint. + OutgoingPacketErrors tcpip.MultiCounterStat + + // MalformedPacketsReceived is the total number of IP Packets that were + // dropped due to the IP packet header failing validation checks. + MalformedPacketsReceived tcpip.MultiCounterStat + + // MalformedFragmentsReceived is the total number of IP Fragments that were + // dropped due to the fragment failing validation checks. + MalformedFragmentsReceived tcpip.MultiCounterStat + + // IPTablesPreroutingDropped is the total number of IP packets dropped in the + // Prerouting chain. + IPTablesPreroutingDropped tcpip.MultiCounterStat + + // IPTablesInputDropped is the total number of IP packets dropped in the Input + // chain. + IPTablesInputDropped tcpip.MultiCounterStat + + // IPTablesOutputDropped is the total number of IP packets dropped in the + // Output chain. + IPTablesOutputDropped tcpip.MultiCounterStat + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived tcpip.MultiCounterStat + + // OptionRRReceived is the number of Record Route options seen. + OptionRRReceived tcpip.MultiCounterStat + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived tcpip.MultiCounterStat +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { + m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) + m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) + m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) + m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered) + m.PacketsSent.Init(a.PacketsSent, b.PacketsSent) + m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors) + m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived) + m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) + m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) + m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) + m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) + m.OptionTSReceived.Init(a.OptionTSReceived, b.OptionTSReceived) + m.OptionRRReceived.Init(a.OptionRRReceived, b.OptionRRReceived) + m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) +} + +// LINT.ThenChange(:MultiCounterIPStats, ../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 3005973d7..47cce79bb 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -18,6 +18,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -167,7 +168,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -189,7 +190,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } @@ -203,7 +204,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net panic("not implemented") } -func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { +func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -219,7 +220,7 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } -func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { +func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -235,14 +236,14 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } -func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) { +func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) { t.Helper() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) - e := channel.New(0, 1280, "") + e := channel.New(1, mtu, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -263,7 +264,7 @@ func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpo func buildDummyStack(t *testing.T) *stack.Stack { t.Helper() - s, _ := buildDummyStackWithLinkEndpoint(t) + s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) return s } @@ -306,8 +307,8 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func TestSourceAddressValidation(t *testing.T) { @@ -416,7 +417,7 @@ func TestSourceAddressValidation(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s, e := buildDummyStackWithLinkEndpoint(t) + s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU) test.rxICMP(e, test.srcAddress) var wantValid uint64 @@ -479,8 +480,9 @@ func TestEnableWhenNICDisabled(t *testing.T) { // Attempting to enable the endpoint while the NIC is disabled should // fail. nic.setEnabled(false) - if err := ep.Enable(); err != tcpip.ErrNotPermitted { - t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted) + err := ep.Enable() + if _, ok := err.(*tcpip.ErrNotPermitted); !ok { + t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{}) } // ep should consider the NIC's enabled status when determining its own // enabled status so we "enable" the NIC to read just the endpoint's @@ -1122,7 +1124,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { remoteAddr tcpip.Address pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) - expectedErr *tcpip.Error + expectedErr tcpip.Error }{ { name: "IPv4", @@ -1187,7 +1189,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { ip.SetHeaderLength(header.IPv4MinimumSize - 1) return hdr.View().ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, { name: "IPv4 too small", @@ -1205,7 +1207,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, { name: "IPv4 minimum size", @@ -1465,7 +1467,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, } @@ -1490,7 +1492,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, }) - e := channel.New(1, 1280, "") + e := channel.New(1, header.IPv6MinimumMTU, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -1506,10 +1508,13 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } defer r.Release() - if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr), - })); err != test.expectedErr { - t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) + { + err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.pktGen(t, subTest.srcAddr), + })) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff) + } } if test.expectedErr != nil { @@ -1526,3 +1531,246 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) } } + +// Test that the included data in an ICMP error packet conforms to the +// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3 +func TestICMPInclusionSize(t *testing.T) { + const ( + replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize + replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize + targetSize4 = header.IPv4MinimumProcessableDatagramSize + targetSize6 = header.IPv6MinimumMTU + // A protocol number that will cause an error response. + reservedProtocol = 254 + ) + + // IPv4 function to create a IP packet and send it to the stack. + // The packet should generate an error response. We can do that by using an + // unknown transport protocol (254). + rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { + totalLen := header.IPv4MinimumSize + len(payload) + hdr := buffer.NewPrependable(header.IPv4MinimumSize) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: reservedProtocol, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + vv := hdr.View().ToVectorisedView() + vv.AppendView(buffer.View(payload)) + // Take a copy before InjectInbound takes ownership of vv + // as vv may be changed during the call. + v := vv.ToView() + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + return v + } + + // IPv6 function to create a packet and send it to the stack. + // The packet should be errant in a way that causes the stack to send an + // ICMP error response and have enough data to allow the testing of the + // inclusion of the errant packet. Use `unknown next header' to generate + // the error. + rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload)), + TransportProtocol: reservedProtocol, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, + }) + vv := hdr.View().ToVectorisedView() + vv.AppendView(buffer.View(payload)) + // Take a copy before InjectInbound takes ownership of vv + // as vv may be changed during the call. + v := vv.ToView() + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + return v + } + + v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { + // We already know the entire packet is the right size so we can use its + // length to calculate the right payload size to check. + expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize + checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()), + checker.SrcAddr(localIPv4Addr), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4ProtoUnreachable), + checker.ICMPv4Payload(payload[:expectedPayloadLength]), + ), + ) + } + + v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) { + // We already know the entire packet is the right size so we can use its + // length to calculate the right payload size to check. + expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize + checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()), + checker.SrcAddr(localIPv6Addr), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6ParamProblem), + checker.ICMPv6Code(header.ICMPv6UnknownHeader), + checker.ICMPv6Payload(payload[:expectedPayloadLength]), + ), + ) + } + tests := []struct { + name string + srcAddress tcpip.Address + injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View + checker func(*testing.T, *stack.PacketBuffer, buffer.View) + payloadLength int // Not including IP header. + linkMTU uint32 // Largest IP packet that the link can send as payload. + replyLength int // Total size of IP/ICMP packet expected back. + }{ + { + name: "IPv4 exact match", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 - replyHeaderLength4, + linkMTU: targetSize4, + replyLength: targetSize4, + }, + { + name: "IPv4 larger MTU", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4, + linkMTU: targetSize4 + 1000, + replyLength: targetSize4, + }, + { + name: "IPv4 smaller MTU", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4, + linkMTU: targetSize4 - 50, + replyLength: targetSize4 - 50, + }, + { + name: "IPv4 payload exceeds", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 + 10, + linkMTU: targetSize4, + replyLength: targetSize4, + }, + { + name: "IPv4 1 byte less", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: targetSize4 - replyHeaderLength4 - 1, + linkMTU: targetSize4, + replyLength: targetSize4 - 1, + }, + { + name: "IPv4 No payload", + srcAddress: remoteIPv4Addr, + injector: rxIPv4Bad, + checker: v4Checker, + payloadLength: 0, + linkMTU: targetSize4, + replyLength: replyHeaderLength4, + }, + { + name: "IPv6 exact match", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6 - replyHeaderLength6, + linkMTU: targetSize6, + replyLength: targetSize6, + }, + { + name: "IPv6 larger MTU", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6, + linkMTU: targetSize6 + 400, + replyLength: targetSize6, + }, + // NB. No "smaller MTU" test here as less than 1280 is not permitted + // in IPv6. + { + name: "IPv6 payload exceeds", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6, + linkMTU: targetSize6, + replyLength: targetSize6, + }, + { + name: "IPv6 1 byte less", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: targetSize6 - replyHeaderLength6 - 1, + linkMTU: targetSize6, + replyLength: targetSize6 - 1, + }, + { + name: "IPv6 no payload", + srcAddress: remoteIPv6Addr, + injector: rxIPv6Bad, + checker: v6Checker, + payloadLength: 0, + linkMTU: targetSize6, + replyLength: replyHeaderLength6, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU) + // Allocate and initialize the payload view. + payload := buffer.NewView(test.payloadLength) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + // Default routes for IPv4&6 so ICMP can find a route to the remote + // node when attempting to send the ICMP error Reply. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + v := test.injector(e, test.srcAddress, payload) + pkt, ok := e.Read() + if !ok { + t.Fatal("expected a packet to be written") + } + if got, want := pkt.Pkt.Size(), test.replyLength; got != want { + t.Fatalf("got %d bytes of icmp error packet, want %d", got, want) + } + test.checker(t, pkt.Pkt, v) + }) + } +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 32f53f217..330a7d170 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -8,6 +8,7 @@ go_library( "icmp.go", "igmp.go", "ipv4.go", + "stats.go", ], visibility = ["//visibility:public"], deps = [ @@ -49,3 +50,15 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "stats_test", + size = "small", + srcs = ["stats_test.go"], + library = ":ipv4", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 8e392f86c..3d93a2cd0 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ package ipv4 import ( - "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -62,21 +61,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - received := stats.ICMP.V4.PacketsReceived + received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { - received.Invalid.Increment() + received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because // we are the only handler, however other types do not cope well with @@ -106,19 +104,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - aux, tmp, err := e.processIPOptions(pkt, opts, op) - if err != nil { - switch { - case - errors.Is(err, header.ErrIPv4OptDuplicate), - errors.Is(err, errIPv4RecordRouteOptInvalidLength), - errors.Is(err, errIPv4RecordRouteOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) - stats.MalformedRcvdPackets.Increment() - stats.IP.MalformedPacketsReceived.Increment() + tmp, optProblem := e.processIPOptions(pkt, opts, op) + if optProblem != nil { + if optProblem.NeedICMP { + _ = e.protocol.returnError(&icmpReasonParamProblem{ + pointer: optProblem.Pointer, + }, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + e.stats.ip.MalformedPacketsReceived.Increment() } return } @@ -128,11 +121,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: - received.Echo.Increment() + received.echo.Increment() - sent := stats.ICMP.V4.PacketsSent + sent := e.stats.icmp.packetsSent if !e.protocol.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return } @@ -213,18 +206,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.EchoReply.Increment() + sent.echoReply.Increment() case header.ICMPv4EchoReply: - received.EchoReply.Increment() + received.echoReply.Increment() e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: - received.DstUnreachable.Increment() + received.dstUnreachable.Increment() pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { @@ -243,31 +236,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } case header.ICMPv4SrcQuench: - received.SrcQuench.Increment() + received.srcQuench.Increment() case header.ICMPv4Redirect: - received.Redirect.Increment() + received.redirect.Increment() case header.ICMPv4TimeExceeded: - received.TimeExceeded.Increment() + received.timeExceeded.Increment() case header.ICMPv4ParamProblem: - received.ParamProblem.Increment() + received.paramProblem.Increment() case header.ICMPv4Timestamp: - received.Timestamp.Increment() + received.timestamp.Increment() case header.ICMPv4TimestampReply: - received.TimestampReply.Increment() + received.timestampReply.Increment() case header.ICMPv4InfoRequest: - received.InfoRequest.Increment() + received.infoRequest.Increment() case header.ICMPv4InfoReply: - received.InfoReply.Increment() + received.infoReply.Increment() default: - received.Invalid.Increment() + received.invalid.Increment() } } @@ -317,7 +310,7 @@ func (*icmpReasonParamProblem) isICMPReason() {} // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { origIPHdr := header.IPv4(pkt.NetworkHeader().View()) origIPHdrSrc := origIPHdr.SourceAddress() origIPHdrDst := origIPHdr.DestinationAddress() @@ -379,9 +372,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - sent := p.stack.Stats().ICMP.V4.PacketsSent + p.mu.Lock() + netEP, ok := p.mu.eps[pkt.NICID] + p.mu.Unlock() + if !ok { + return &tcpip.ErrNotConnected{} + } + + sent := netEP.stats.icmp.packetsSent + if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return nil } @@ -435,13 +436,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi // systems implement the RFC 1812 definition and not the original // requirement. We treat 8 bytes as the minimum but will try send more. mtu := int(route.MTU()) - if mtu > header.IPv4MinimumProcessableDatagramSize { - mtu = header.IPv4MinimumProcessableDatagramSize + const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize + if mtu > maxIPData { + mtu = maxIPData } - headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize - available := int(mtu) - headerLen + available := mtu - header.ICMPv4MinimumSize - if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { + if available < len(origIPHdr)+header.ICMPv4MinimumErrorPayloadSize { return nil } @@ -464,36 +465,36 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi payload.CapLength(payloadLen) icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, + ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize, Data: payload, }) icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter *tcpip.StatCounter + var counter tcpip.MultiCounterStat switch reason := reason.(type) { case *icmpReasonPortUnreachable: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonProtoUnreachable: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonParamProblem: icmpHdr.SetType(header.ICMPv4ParamProblem) icmpHdr.SetCode(header.ICMPv4UnusedCode) icmpHdr.SetPointer(reason.pointer) - counter = sent.ParamProblem + counter = sent.paramProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } @@ -508,7 +509,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi }, icmpPkt, ); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } counter.Increment() diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index d9b5fe6ed..4cd0b3256 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -103,7 +103,7 @@ func (igmp *igmpState) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { igmpType := header.IGMPv2MembershipReport if igmp.v1Present() { igmpType = header.IGMPv1MembershipReport @@ -114,7 +114,7 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Erro // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error { // As per RFC 2236 Section 6, Page 8: "If the interface state says the // Querier is running IGMPv1, this action SHOULD be skipped. If the flag // saying we were the last host to report is cleared, this action MAY be @@ -149,51 +149,49 @@ func (igmp *igmpState) init(ep *endpoint) { // // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { - stats := igmp.ep.protocol.stack.Stats() - received := stats.IGMP.PacketsReceived + received := igmp.ep.stats.igmp.packetsReceived headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.IGMP(headerView) - // Temporarily reset the checksum field to 0 in order to calculate the proper - // checksum. - wantChecksum := h.Checksum() - h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - h.SetChecksum(wantChecksum) - - if gotChecksum != wantChecksum { - received.ChecksumErrors.Increment() + // As per RFC 1071 section 1.3, + // + // To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xFFFF { + received.checksumErrors.Increment() return } switch h.Type() { case header.IGMPMembershipQuery: - received.MembershipQuery.Increment() + received.membershipQuery.Increment() if len(headerView) < header.IGMPQueryMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) case header.IGMPv1MembershipReport: - received.V1MembershipReport.Increment() + received.v1MembershipReport.Increment() if len(headerView) < header.IGMPReportMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipReport(h.GroupAddress()) case header.IGMPv2MembershipReport: - received.V2MembershipReport.Increment() + received.v2MembershipReport.Increment() if len(headerView) < header.IGMPReportMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipReport(h.GroupAddress()) case header.IGMPLeaveGroup: - received.LeaveGroup.Increment() + received.leaveGroup.Increment() // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or // Report, are ignored in all states" @@ -201,7 +199,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should // be silently ignored. New message types may be used by newer versions of // IGMP, by multicast routing protocols, or other uses." - received.Unrecognized.Increment() + received.unrecognized.Increment() } } @@ -244,7 +242,7 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { // writePacket assembles and sends an IGMP packet. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, tcpip.Error) { igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) igmpData.SetType(igmpType) igmpData.SetGroupAddress(groupAddress) @@ -272,18 +270,18 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip panic(fmt.Sprintf("failed to add IP header: %s", err)) } - sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + sentStats := igmp.ep.stats.igmp.packetsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sentStats.Dropped.Increment() + sentStats.dropped.Increment() return false, err } switch igmpType { case header.IGMPv1MembershipReport: - sentStats.V1MembershipReport.Increment() + sentStats.v1MembershipReport.Increment() case header.IGMPv2MembershipReport: - sentStats.V2MembershipReport.Increment() + sentStats.v2MembershipReport.Increment() case header.IGMPLeaveGroup: - sentStats.LeaveGroup.Increment() + sentStats.leaveGroup.Increment() default: panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) } @@ -295,7 +293,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // messages. // // If the group already exists in the membership map, returns -// tcpip.ErrDuplicateAddress. +// *tcpip.ErrDuplicateAddress. // // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { @@ -314,13 +312,13 @@ func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { // if required. // // Precondition: igmp.ep.mu must be locked. -func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) tcpip.Error { // LeaveGroup returns false only if the group was not joined. if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } // softLeaveAll leaves all groups from the perspective of IGMP, but remains diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bb25a76fe..e5c80699d 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ package ipv4 import ( - "errors" "fmt" "math" + "reflect" "sync/atomic" "time" @@ -73,6 +73,7 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol + stats sharedStats // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -114,18 +115,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa e.mu.addressableEndpointState.Init(e) e.mu.igmp.init(e) e.mu.Unlock() + + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + + stackStats := p.stack.Stats() + e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP) + e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V4) + e.stats.igmp.init(&e.stats.localStats.IGMP, &stackStats.IGMP) + + p.mu.Lock() + p.mu.eps[nic.ID()] = e + p.mu.Unlock() + return e } +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, nicID) +} + // Enable implements stack.NetworkEndpoint. -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If the NIC is not enabled, the endpoint can't do anything meaningful so // don't enable the endpoint. if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } // If the endpoint is already enabled, there is nothing for it to do. @@ -193,7 +212,9 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } @@ -202,7 +223,9 @@ func (e *endpoint) disableLocked() { e.mu.igmp.softLeaveAll() // The address may have already been removed. - if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } @@ -237,7 +260,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) tcpip.Error { hdrLen := header.IPv4MinimumSize var optLen int if options != nil { @@ -245,12 +268,12 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet } hdrLen += optLen if hdrLen > header.IPv4MaximumHeaderSize { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := pkt.Size() if length > math.MaxUint16 { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams @@ -275,7 +298,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the // original packet. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { // Round the MTU down to align to 8 bytes. fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -295,17 +318,17 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { return err } // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. - e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() + e.stats.ip.IPTablesOutputDropped.Increment() return nil } @@ -334,7 +357,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return e.writePacket(r, gso, pkt, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { @@ -349,35 +372,37 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + stats := e.stats.ip + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) }) - r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + stats.PacketsSent.IncrementBy(uint64(sent)) + stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } - r.Stats().IP.PacketsSent.Increment() + stats.PacketsSent.Increment() return nil } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } @@ -385,6 +410,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + stats := e.stats.ip + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { return 0, err @@ -392,7 +419,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } @@ -400,7 +427,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt @@ -413,21 +440,21 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } return n, err } - r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. @@ -451,36 +478,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) + stats.PacketsSent.IncrementBy(uint64(n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err } n++ } - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) // Dropped packets aren't errors, so include them in the return value. return n + len(dropped), nil } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } hdrLen := header.IPv4(h).HeaderLength() if hdrLen < header.IPv4MinimumSize { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } h, ok = pkt.Data.PullUp(int(hdrLen)) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } ip := header.IPv4(h) @@ -518,14 +545,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // wire format. We also want to check if the header's fields are valid before // sending the packet. if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv4(pkt.NetworkHeader().View()) ttl := h.TTL() if ttl == 0 { @@ -545,7 +572,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { networkEndpoint.(*endpoint).handlePacket(pkt) return nil } - if err != tcpip.ErrBadAddress { + if _, ok := err.(*tcpip.ErrBadAddress); !ok { return err } @@ -577,19 +604,21 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - stats.IP.PacketsReceived.Increment() + stats := e.stats.ip + + stats.PacketsReceived.Increment() if !e.isEnabled() { - stats.IP.DisabledPacketsReceived.Increment() + stats.DisabledPacketsReceived.Increment() return } // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesPreroutingDropped.Increment() + stats.IPTablesPreroutingDropped.Increment() return } } @@ -601,11 +630,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() - stats := e.protocol.stack.Stats() + stats := e.stats h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.IP.MalformedPacketsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() return } @@ -631,7 +660,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - stats.IP.MalformedPacketsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() return } @@ -643,7 +672,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // be one of its own IP addresses (but not a broadcast or // multicast address). if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.ip.InvalidSourceAddressesReceived.Increment() return } // Make sure the source address is not a subnet-local broadcast address. @@ -651,7 +680,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { subnet := addressEndpoint.Subnet() addressEndpoint.DecRef() if subnet.IsBroadcast(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.ip.InvalidSourceAddressesReceived.Increment() return } } @@ -664,7 +693,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() + stats.ip.InvalidDestinationAddressesReceived.Increment() return } @@ -674,9 +703,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesInputDropped.Increment() + stats.ip.IPTablesInputDropped.Increment() return } @@ -684,8 +714,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. @@ -698,8 +728,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } @@ -720,8 +750,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt, ) if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } if !ready { @@ -734,7 +764,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) h.SetFlagsFragmentOffset(0, 0) } - stats.IP.PacketsDelivered.Increment() + stats.ip.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { @@ -755,19 +785,13 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options // rather than just throwing them away. - aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{}) - if err != nil { - switch { - case - errors.Is(err, header.ErrIPv4OptDuplicate), - errors.Is(err, errIPv4RecordRouteOptInvalidPointer), - errors.Is(err, errIPv4RecordRouteOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidLength), - errors.Is(err, errIPv4TimestampOptInvalidPointer), - errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) - stats.MalformedRcvdPackets.Increment() - stats.IP.MalformedPacketsReceived.Increment() + if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{}); optProblem != nil { + if optProblem.NeedICMP { + _ = e.protocol.returnError(&icmpReasonParamProblem{ + pointer: optProblem.Pointer, + }, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + stats.ip.MalformedPacketsReceived.Increment() } return } @@ -800,10 +824,12 @@ func (e *endpoint) Close() { e.disableLocked() e.mu.addressableEndpointState.Cleanup() + + e.protocol.forgetEndpoint(e.nic.ID()) } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() @@ -815,7 +841,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p } // RemovePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.mu.addressableEndpointState.RemovePermanentAddress(addr) @@ -872,7 +898,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) @@ -881,9 +907,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { // joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error { if !header.IsV4MulticastAddress(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } e.mu.igmp.joinGroup(addr) @@ -891,7 +917,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) @@ -900,7 +926,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error { return e.mu.igmp.leaveGroup(addr) } @@ -911,6 +937,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { return e.mu.igmp.isInGroup(addr) } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -918,6 +949,14 @@ var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack + mu struct { + sync.RWMutex + + // eps is keyed by NICID to allow protocol methods to retrieve an endpoint + // when handling a packet, by looking at which NIC handled the packet. + eps map[tcpip.NICID]*endpoint + } + // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful. // @@ -960,24 +999,24 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: p.SetDefaultTTL(uint8(*v)) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -1023,9 +1062,9 @@ func (p *protocol) SetForwarding(v bool) { // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. -func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { if linkMTU < header.IPv4MinimumMTU { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in @@ -1033,7 +1072,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Erro // The maximal internet header is 60 octets, and a typical internet header // is 20 octets, allowing a margin for headers of higher level protocols. if networkHeaderSize > header.IPv4MaximumHeaderSize { - return 0, tcpip.ErrMalformedHeader + return 0, &tcpip.ErrMalformedHeader{} } networkMTU := linkMTU @@ -1095,6 +1134,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { options: opts, } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + p.mu.eps = make(map[tcpip.NICID]*endpoint) return p } } @@ -1192,16 +1232,9 @@ func (*optionUsageEcho) actions() optionActions { } } -var ( - errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") - errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") - errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") - errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") -) - // handleTimestamp does any required processing on a Timestamp option // in place. -func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) *header.IPv4OptParameterProblem { flags := tsOpt.Flags() var entrySize uint8 switch flags { @@ -1212,7 +1245,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres header.IPv4OptionTimestampWithPredefinedIPFlag: entrySize = header.IPv4OptionTimestampWithAddrSize default: - return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSOFLWAndFLGOffset, + NeedICMP: true, + } } pointer := tsOpt.Pointer() @@ -1220,7 +1256,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres // Since the pointer is 1 based, and the header is 4 bytes long the // pointer must point beyond the header therefore 4 or less is bad. if pointer <= header.IPv4OptionTimestampHdrLength { - return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSPointerOffset, + NeedICMP: true, + } } // To simplify processing below, base further work on the array of timestamps // beyond the header, rather than on the whole option. Also to aid @@ -1254,14 +1293,17 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres // timestamp, but the overflow count is incremented by one. if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { // By definition we have nothing to do. - return 0, nil + return nil } if tsOpt.IncOverflow() != 0 { - return 0, nil + return nil } // The overflow count is also full. - return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSOFLWAndFLGOffset, + NeedICMP: true, + } } if nextSlot+entrySize > dataLength { // The data area isn't full but there isn't room for a new entry. @@ -1280,32 +1322,36 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres if dataLength%entrySize != 0 { // The Data section size should be a multiple of the expected // timestamp entry size. - return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: false, + } } // If the size is OK, the pointer must be corrupted. } - return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptTSPointerOffset, + NeedICMP: true, + } } if usage.actions().timestamp == optionProcess { tsOpt.UpdateTimestamp(localAddress, clock) } - return 0, nil + return nil } -var ( - errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") - errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") -) - // handleRecordRoute checks and processes a Record route option. It is much // like the timestamp type 1 option, but without timestamps. The passed in // address is stored in the option in the correct spot if possible. -func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) *header.IPv4OptParameterProblem { optlen := rrOpt.Size() if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { - return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: true, + } } pointer := rrOpt.Pointer() @@ -1315,7 +1361,10 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // Since the pointer is 1 based, and the header is 3 bytes long the // pointer must point beyond the header therefore 3 or less is bad. if pointer <= header.IPv4OptionRecordRouteHdrLength { - return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptRRPointerOffset, + NeedICMP: true, + } } // RFC 791 page 21 says @@ -1332,7 +1381,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // of these words is a copy/paste error from the timestamp option where // there are two failure reasons given. if pointer > optlen { - return 0, nil + return nil } // The data area isn't full but there isn't room for a new entry. @@ -1357,17 +1406,23 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // } if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { // Length is bad, not on integral number of slots. - return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptionLengthOffset, + NeedICMP: true, + } } // If not length, the fault must be with the pointer. } - return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + return &header.IPv4OptParameterProblem{ + Pointer: header.IPv4OptRRPointerOffset, + NeedICMP: true, + } } if usage.actions().recordRoute == optionVerify { - return 0, nil + return nil } rrOpt.StoreAddress(localAddress) - return 0, nil + return nil } // processIPOptions parses the IPv4 options and produces a new set of options @@ -1378,8 +1433,8 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // - The location of an error if there was one (or 0 if no error) // - If there is an error, information as to what it was was. // - The replacement option set. -func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { - stats := e.protocol.stack.Stats() +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, *header.IPv4OptParameterProblem) { + stats := e.stats.ip opts := header.IPv4Options(orig) optIter := opts.MakeIterator() @@ -1392,21 +1447,23 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt // This will need tweaking when we start really forwarding packets // as we may need to get two addresses, for rx and tx interfaces. // We will also have to take usage into account. - prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) + prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) localAddress := prefixedAddress.Address - if err != nil { + if !ok { h := header.IPv4(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { - return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + return nil, &header.IPv4OptParameterProblem{ + NeedICMP: false, + } } localAddress = dstAddr } for { - option, done, err := optIter.Next() - if done || err != nil { - return optIter.ErrCursor, optIter.Finalize(), err + option, done, optProblem := optIter.Next() + if done || optProblem != nil { + return optIter.Finalize(), optProblem } optType := option.Type() if optType == header.IPv4OptionNOPType { @@ -1415,44 +1472,47 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt } if optType == header.IPv4OptionListEndType { optIter.PushNOPOrEnd(optType) - return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + return optIter.Finalize(), nil } // check for repeating options (multiple NOPs are OK) if seenOptions[optType] { - return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + return nil, &header.IPv4OptParameterProblem{ + Pointer: optIter.ErrCursor, + NeedICMP: true, + } } seenOptions[optType] = true optLen := int(option.Size()) switch option := option.(type) { case *header.IPv4OptionTimestamp: - stats.IP.OptionTSReceived.Increment() + stats.OptionTSReceived.Increment() if usage.actions().timestamp != optionRemove { clock := e.protocol.stack.Clock() newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) - offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) - if err != nil { - return optIter.ErrCursor + offset, nil, err + if optProblem := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage); optProblem != nil { + optProblem.Pointer += optIter.ErrCursor + return nil, optProblem } optIter.ConsumeBuffer(optLen) } case *header.IPv4OptionRecordRoute: - stats.IP.OptionRRReceived.Increment() + stats.OptionRRReceived.Increment() if usage.actions().recordRoute != optionRemove { newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) - offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) - if err != nil { - return optIter.ErrCursor + offset, nil, err + if optProblem := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage); optProblem != nil { + optProblem.Pointer += optIter.ErrCursor + return nil, optProblem } optIter.ConsumeBuffer(optLen) } default: - stats.IP.OptionUnknownReceived.Increment() + stats.OptionUnknownReceived.Increment() if usage.actions().unknown == optionPass { newBuffer := optIter.RemainingBuffer()[:optLen] // Arguments already heavily checked.. ignore result. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index a9e137c24..ed5899f0b 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -78,8 +78,11 @@ func TestExcludeBroadcast(t *testing.T) { defer ep.Close() // Cannot connect using a broadcast address as the source. - if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) + { + err := ep.Connect(randomAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{}) + } } // However, we can bind to a broadcast address to listen. @@ -270,6 +273,11 @@ func TestIPv4Sanity(t *testing.T) { nicID = 1 randomSequence = 123 randomIdent = 42 + // In some cases Linux sets the error pointer to the start of the option + // (offset 0) instead of the actual wrong value, which is the length byte + // (offset 1). For compatibility we must do the same. Use this constant + // to indicate where this happens. + pointerOffsetForInvalidLength = 0 ) var ( ipv4Addr = tcpip.AddressWithPrefix{ @@ -439,6 +447,21 @@ func TestIPv4Sanity(t *testing.T) { replyOptions: header.IPv4Options{}, }, { + name: "bad option - no length", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 1, 1, 1, 68, + // ^-start of timestamp.. but no length.. + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + }, + { name: "bad option - length 0", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), @@ -448,7 +471,27 @@ func TestIPv4Sanity(t *testing.T) { // ^ 1, 2, 3, 4, }, - shouldFail: true, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, + }, + { + name: "bad option - length 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 1, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "bad option - length big", @@ -462,7 +505,11 @@ func TestIPv4Sanity(t *testing.T) { // space is not possible. (Second byte) 1, 2, 3, 4, }, - shouldFail: true, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { // This tests for some linux compatible behaviour. @@ -484,7 +531,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, { name: "multiple type 0 with room", @@ -589,7 +636,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, { name: "valid timestamp pointer", @@ -624,7 +671,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset, }, // End of option list with illegal option after it, which should be ignored. { @@ -636,24 +683,31 @@ func TestIPv4Sanity(t *testing.T) { 68, 12, 13, 0x11, 192, 168, 1, 12, 1, 2, 3, 4, - 0, 10, 3, 99, + 0, 10, 3, 99, // EOL followed by junk }, replyOptions: header.IPv4Options{ 68, 12, 13, 0x21, 192, 168, 1, 12, 1, 2, 3, 4, - 0, 0, 0, 0, // 3 bytes unknown option - }, // ^ End of options hides following bytes. + 0, // End of Options hides following bytes. + 0, 0, 0, // 3 bytes unknown option removed. + }, }, { - // Timestamp with a size too small. + // Timestamp with a size much too small. name: "timestamp truncated", maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, - options: header.IPv4Options{68, 1, 0, 0}, - // ^ Smallest possible is 8. - shouldFail: true, + options: header.IPv4Options{ + 68, 1, 0, 0, + // ^ Smallest possible is 8. Linux points at the 68. + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength, }, { name: "single record route with room", @@ -751,7 +805,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { // Pointer must be 4 or more as it must point past the 3 byte header @@ -769,7 +823,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { // Pointer must be 4 or more as it must point past the 3 byte header @@ -808,8 +862,7 @@ func TestIPv4Sanity(t *testing.T) { expectErrorICMP: true, ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, - paramProblemPointer: header.IPv4MinimumSize + 2, - replyOptions: header.IPv4Options{}, + paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset, }, { name: "duplicate record route", @@ -828,7 +881,6 @@ func TestIPv4Sanity(t *testing.T) { ICMPType: header.ICMPv4ParamProblem, ICMPCode: header.ICMPv4UnusedCode, paramProblemPointer: header.IPv4MinimumSize + 7, - replyOptions: header.IPv4Options{}, }, } @@ -884,7 +936,6 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } - ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: test.transportProtocol, @@ -1328,8 +1379,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize int allowPackets int outgoingErrors int - mockError *tcpip.Error - wantError *tcpip.Error + mockError tcpip.Error + wantError tcpip.Error }{ { description: "No frag", @@ -1338,8 +1389,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 0, outgoingErrors: 1, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag", @@ -1348,8 +1399,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 0, outgoingErrors: 3, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on second frag", @@ -1358,8 +1409,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 1, outgoingErrors: 2, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag MTU smaller than header", @@ -1368,8 +1419,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize: 500, allowPackets: 0, outgoingErrors: 4, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error when MTU is smaller than IPv4 minimum MTU", @@ -1379,7 +1430,7 @@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrInvalidEndpointState, + wantError: &tcpip.ErrInvalidEndpointState{}, }, } @@ -1393,8 +1444,8 @@ func TestFragmentationErrors(t *testing.T) { TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.wantError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + if diff := cmp.Diff(ft.wantError, err); diff != "" { + t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff) } if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) @@ -2427,8 +2478,9 @@ func TestReceiveFragments(t *testing.T) { } } - if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -2506,11 +2558,11 @@ func TestWriteStats(t *testing.T) { // Parameterize the tests to run with both WritePacket and WritePackets. writers := []struct { name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) }{ { name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { @@ -2522,7 +2574,7 @@ func TestWriteStats(t *testing.T) { }, }, { name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) }, }, @@ -2532,7 +2584,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -2608,7 +2660,7 @@ func (*limitedMatcher) Name() string { } // Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go new file mode 100644 index 000000000..bee72c649 --- /dev/null +++ b/pkg/tcpip/network/ipv4/stats.go @@ -0,0 +1,190 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.IPNetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to the IPv4 protocol family. +type Stats struct { + // IP holds IPv4 statistics. + IP tcpip.IPStats + + // IGMP holds IGMP statistics. + IGMP tcpip.IGMPStats + + // ICMP holds ICMPv4 statistics. + ICMP tcpip.ICMPv4Stats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*Stats) IsNetworkEndpointStats() {} + +// IPStats implements stack.IPNetworkEndointStats +func (s *Stats) IPStats() *tcpip.IPStats { + return &s.IP +} + +type sharedStats struct { + localStats Stats + ip ip.MultiCounterIPStats + icmp multiCounterICMPv4Stats + igmp multiCounterIGMPStats +} + +// LINT.IfChange(multiCounterICMPv4PacketStats) + +type multiCounterICMPv4PacketStats struct { + echo tcpip.MultiCounterStat + echoReply tcpip.MultiCounterStat + dstUnreachable tcpip.MultiCounterStat + srcQuench tcpip.MultiCounterStat + redirect tcpip.MultiCounterStat + timeExceeded tcpip.MultiCounterStat + paramProblem tcpip.MultiCounterStat + timestamp tcpip.MultiCounterStat + timestampReply tcpip.MultiCounterStat + infoRequest tcpip.MultiCounterStat + infoReply tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) { + m.echo.Init(a.Echo, b.Echo) + m.echoReply.Init(a.EchoReply, b.EchoReply) + m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable) + m.srcQuench.Init(a.SrcQuench, b.SrcQuench) + m.redirect.Init(a.Redirect, b.Redirect) + m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded) + m.paramProblem.Init(a.ParamProblem, b.ParamProblem) + m.timestamp.Init(a.Timestamp, b.Timestamp) + m.timestampReply.Init(a.TimestampReply, b.TimestampReply) + m.infoRequest.Init(a.InfoRequest, b.InfoRequest) + m.infoReply.Init(a.InfoReply, b.InfoReply) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4PacketStats) + +// LINT.IfChange(multiCounterICMPv4SentPacketStats) + +type multiCounterICMPv4SentPacketStats struct { + multiCounterICMPv4PacketStats + dropped tcpip.MultiCounterStat + rateLimited tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4SentPacketStats) init(a, b *tcpip.ICMPv4SentPacketStats) { + m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats) + m.dropped.Init(a.Dropped, b.Dropped) + m.rateLimited.Init(a.RateLimited, b.RateLimited) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4SentPacketStats) + +// LINT.IfChange(multiCounterICMPv4ReceivedPacketStats) + +type multiCounterICMPv4ReceivedPacketStats struct { + multiCounterICMPv4PacketStats + invalid tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4ReceivedPacketStats) init(a, b *tcpip.ICMPv4ReceivedPacketStats) { + m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats) + m.invalid.Init(a.Invalid, b.Invalid) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4ReceivedPacketStats) + +// LINT.IfChange(multiCounterICMPv4Stats) + +type multiCounterICMPv4Stats struct { + packetsSent multiCounterICMPv4SentPacketStats + packetsReceived multiCounterICMPv4ReceivedPacketStats +} + +func (m *multiCounterICMPv4Stats) init(a, b *tcpip.ICMPv4Stats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4Stats) + +// LINT.IfChange(multiCounterIGMPPacketStats) + +type multiCounterIGMPPacketStats struct { + membershipQuery tcpip.MultiCounterStat + v1MembershipReport tcpip.MultiCounterStat + v2MembershipReport tcpip.MultiCounterStat + leaveGroup tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPPacketStats) init(a, b *tcpip.IGMPPacketStats) { + m.membershipQuery.Init(a.MembershipQuery, b.MembershipQuery) + m.v1MembershipReport.Init(a.V1MembershipReport, b.V1MembershipReport) + m.v2MembershipReport.Init(a.V2MembershipReport, b.V2MembershipReport) + m.leaveGroup.Init(a.LeaveGroup, b.LeaveGroup) +} + +// LINT.ThenChange(../../tcpip.go:IGMPPacketStats) + +// LINT.IfChange(multiCounterIGMPSentPacketStats) + +type multiCounterIGMPSentPacketStats struct { + multiCounterIGMPPacketStats + dropped tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPSentPacketStats) init(a, b *tcpip.IGMPSentPacketStats) { + m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats) + m.dropped.Init(a.Dropped, b.Dropped) +} + +// LINT.ThenChange(../../tcpip.go:IGMPSentPacketStats) + +// LINT.IfChange(multiCounterIGMPReceivedPacketStats) + +type multiCounterIGMPReceivedPacketStats struct { + multiCounterIGMPPacketStats + invalid tcpip.MultiCounterStat + checksumErrors tcpip.MultiCounterStat + unrecognized tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPReceivedPacketStats) init(a, b *tcpip.IGMPReceivedPacketStats) { + m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats) + m.invalid.Init(a.Invalid, b.Invalid) + m.checksumErrors.Init(a.ChecksumErrors, b.ChecksumErrors) + m.unrecognized.Init(a.Unrecognized, b.Unrecognized) +} + +// LINT.ThenChange(../../tcpip.go:IGMPReceivedPacketStats) + +// LINT.IfChange(multiCounterIGMPStats) + +type multiCounterIGMPStats struct { + packetsSent multiCounterIGMPSentPacketStats + packetsReceived multiCounterIGMPReceivedPacketStats +} + +func (m *multiCounterIGMPStats) init(a, b *tcpip.IGMPStats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:IGMPStats) diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go new file mode 100644 index 000000000..b28e7dcde --- /dev/null +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.NetworkInterface + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func knownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + nic := testInterface{nicID: 1} + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) + } + + ep.Close() + + proto.mu.Lock() + _, hasEP := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } +} + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // expected to be bound by a MultiCounterStat. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index afa45aefe..0c5f8d683 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -10,6 +10,7 @@ go_library( "ipv6.go", "mld.go", "ndp.go", + "stats.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6ee162713..7298bd061 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -125,15 +125,14 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { - stats := e.protocol.stack.Stats().ICMP - sent := stats.V6.PacketsSent - received := stats.V6.PacketsReceived + sent := e.stats.icmp.packetsSent + received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.ICMPv6(v) @@ -147,7 +146,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { payload := pkt.Data.Clone(nil) payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -165,10 +164,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // TODO(b/112892170): Meaningfully handle all ICMP types. switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: - received.PacketTooBig.Increment() + received.packetTooBig.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) @@ -179,10 +178,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: - received.DstUnreachable.Increment() + received.dstUnreachable.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) @@ -194,9 +193,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6NeighborSolicit: - received.NeighborSolicit.Increment() + received.neighborSolicit.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -210,7 +209,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast // address. if header.IsV6MulticastAddress(targetAddr) { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -238,7 +237,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: + default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) } } @@ -263,13 +264,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { if err != nil { // Options are not valid as per the wire format, silently drop the // packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok = getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } } @@ -282,16 +283,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { unspecifiedSource := srcAddr == header.IPv6Any if len(sourceLinkAddr) == 0 { if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource { - received.Invalid.Increment() + received.invalid.Increment() return } } else if unspecifiedSource { - received.Invalid.Increment() + received.invalid.Increment() return } else if e.nud != nil { e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr) } // As per RFC 4861 section 7.1.1: @@ -301,7 +302,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // - If the IP source address is the unspecified address, the IP // destination address is a solicited-node multicast address. if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -379,15 +380,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.NeighborAdvert.Increment() + sent.neighborAdvert.Increment() case header.ICMPv6NeighborAdvert: - received.NeighborAdvert.Increment() + received.neighborAdvert.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -414,16 +415,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: + return + default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) } - return } it, err := na.Options().Iter(false /* check */) if err != nil { // If we have a malformed NDP NA option, drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } @@ -438,7 +441,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // DAD. targetLinkAddr, ok := getTargetLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() + return + } + + // As per RFC 4861 section 7.1.2: + // A node MUST silently discard any received Neighbor Advertisement + // messages that do not satisfy all of the following validity checks: + // ... + // - If the IP Destination Address is a multicast address the + // Solicited flag is zero. + if header.IsV6MulticastAddress(dstAddr) && na.SolicitedFlag() { + received.invalid.Increment() return } @@ -446,7 +460,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // address cache with the link address for the target of the message. if e.nud == nil { if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) + e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr) } return } @@ -458,10 +472,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { }) case header.ICMPv6EchoRequest: - received.EchoRequest.Increment() + received.echoRequest.Increment() icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -493,27 +507,27 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, }, replyPkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.EchoReply.Increment() + sent.echoReply.Increment() case header.ICMPv6EchoReply: - received.EchoReply.Increment() + received.echoReply.Increment() if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: - received.TimeExceeded.Increment() + received.timeExceeded.Increment() case header.ICMPv6ParamProblem: - received.ParamProblem.Increment() + received.paramProblem.Increment() case header.ICMPv6RouterSolicit: - received.RouterSolicit.Increment() + received.routerSolicit.Increment() // // Validate the RS as per RFC 4861 section 6.1.1. @@ -521,7 +535,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the NDP payload of sufficient size to hold a Router Solictation? if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -530,7 +544,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the networking stack operating as a router? if !stack.Forwarding(ProtocolNumber) { // ... No, silently drop the packet. - received.RouterOnlyPacketsDroppedByHost.Increment() + received.routerOnlyPacketsDroppedByHost.Increment() return } @@ -540,13 +554,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { it, err := rs.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok := getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -557,7 +571,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // NOT be included when the source IP address is the unspecified address. // Otherwise, it SHOULD be included on link layers that have addresses. if srcAddr == header.IPv6Any { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -569,7 +583,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6RouterAdvert: - received.RouterAdvert.Increment() + received.routerAdvert.Increment() // // Validate the RA as per RFC 4861 section 6.1.2. @@ -577,7 +591,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the NDP payload of sufficient size to hold a Router Advertisement? if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -586,7 +600,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { // ...No, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } @@ -596,13 +610,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { it, err := ra.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok := getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -638,26 +652,26 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // link-layer address be modified due to receiving one of the above // messages, the state SHOULD also be set to STALE to provide prompt // verification that the path to the new link-layer address is working." - received.RedirectMsg.Increment() + received.redirectMsg.Increment() if !isNDPValid() { - received.Invalid.Increment() + received.invalid.Increment() return } case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: switch icmpType { case header.ICMPv6MulticastListenerQuery: - received.MulticastListenerQuery.Increment() + received.multicastListenerQuery.Increment() case header.ICMPv6MulticastListenerReport: - received.MulticastListenerReport.Increment() + received.multicastListenerReport.Increment() case header.ICMPv6MulticastListenerDone: - received.MulticastListenerDone.Increment() + received.multicastListenerDone.Increment() default: panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -676,7 +690,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } default: - received.Unrecognized.Increment() + received.unrecognized.Increment() } } @@ -688,26 +702,39 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error { + nicID := nic.ID() + + p.mu.Lock() + netEP, ok := p.mu.eps[nicID] + p.mu.Unlock() + if !ok { + return &tcpip.ErrNotConnected{} + } + remoteAddr := targetAddr if len(remoteLinkAddr) == 0 { remoteAddr = header.SolicitedNodeAddr(targetAddr) remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + if len(localAddr) == 0 { + addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) + if addressEndpoint == nil { + return &tcpip.ErrNetworkUnreachable{} + } + + localAddr = addressEndpoint.AddressWithPrefix().Address + } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 { + return &tcpip.ErrBadLocalAddress{} } - defer r.Release() - r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) @@ -715,18 +742,23 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot ns := header.NDPNeighborSolicit(packet.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) - packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, localAddr, remoteAddr, buffer.VectorisedView{})) - stat := p.stack.Stats().ICMP.V6.PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, remoteAddr, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - }, pkt); err != nil { - stat.Dropped.Increment() + }, header.IPv6ExtHdrSerializer{}); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } + + stat := netEP.stats.icmp.packetsSent + + if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + stat.dropped.Increment() return err } - stat.NeighborSolicit.Increment() + stat.neighborSolicit.Increment() return nil } @@ -796,7 +828,7 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { origIPHdr := header.IPv6(pkt.NetworkHeader().View()) origIPHdrSrc := origIPHdr.SourceAddress() origIPHdrDst := origIPHdr.DestinationAddress() @@ -863,10 +895,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - stats := p.stack.Stats().ICMP - sent := stats.V6.PacketsSent + p.mu.Lock() + netEP, ok := p.mu.eps[pkt.NICID] + p.mu.Unlock() + if !ok { + return &tcpip.ErrNotConnected{} + } + + sent := netEP.stats.icmp.packetsSent + if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return nil } @@ -897,11 +936,11 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi // the error message packet exceed the minimum IPv6 MTU // [IPv6]. mtu := int(route.MTU()) - if mtu > header.IPv6MinimumMTU { - mtu = header.IPv6MinimumMTU + const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize + if mtu > maxIPv6Data { + mtu = maxIPv6Data } - headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize - available := int(mtu) - headerLen + available := mtu - header.ICMPv6ErrorHeaderSize if available < header.IPv6MinimumSize { return nil } @@ -915,31 +954,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, + ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize, Data: payload, }) newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - var counter *tcpip.StatCounter + var counter tcpip.MultiCounterStat switch reason := reason.(type) { case *icmpReasonParameterProblem: icmpHdr.SetType(header.ICMPv6ParamProblem) icmpHdr.SetCode(reason.code) icmpHdr.SetTypeSpecific(reason.pointer) - counter = sent.ParamProblem + counter = sent.paramProblem case *icmpReasonPortUnreachable: icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonHopLimitExceeded: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) - counter = sent.TimeExceeded + counter = sent.timeExceeded default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } @@ -953,7 +992,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi }, newPkt, ); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } counter.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index b1e6a70a2..db1c2e663 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "bytes" "context" "net" "reflect" @@ -22,6 +23,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -77,7 +79,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { return nil } @@ -91,16 +93,11 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st return stack.TransportPacketHandled } -type stubLinkAddressCache struct { - stack.LinkAddressCache -} +var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil) -func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID { - return 0 -} +type stubLinkAddressCache struct{} -func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { -} +func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {} type stubNUDHandler struct { probeCount int @@ -148,15 +145,15 @@ func (*testInterface) Promiscuous() bool { return false } -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) } -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) } -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var r stack.RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr @@ -643,7 +640,6 @@ func TestLinkResolution(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - payload := tcpip.SlicePayload(hdr.View()) // We can't send our payload directly over the route because that // doesn't provoke NDP discovery. @@ -653,8 +649,12 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } - if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { - t.Fatalf("ep.Write(_): %s", err) + { + var r bytes.Reader + r.Reset(hdr.View()) + if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { + t.Fatalf("ep.Write(_): %s", err) + } } for _, args := range []routeArgs{ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, @@ -1283,7 +1283,7 @@ func TestLinkAddressRequest(t *testing.T) { localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectedErr *tcpip.Error + expectedErr tcpip.Error expectedRemoteAddr tcpip.Address expectedRemoteLinkAddr tcpip.LinkAddress }{ @@ -1321,79 +1321,80 @@ func TestLinkAddressRequest(t *testing.T) { name: "Unicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: linkAddr1, - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrBadLocalAddress{}, }, { name: "Multicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrBadLocalAddress{}, }, { name: "Unicast with no local address available", remoteLinkAddr: linkAddr1, - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrNetworkUnreachable{}, }, { name: "Multicast with no local address available", remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrNetworkUnreachable{}, }, } for _, test := range tests { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - p := s.NetworkProtocolInstance(ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + p := s.NetworkProtocolInstance(ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := s.CreateNIC(nicID, linkEP); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } } - } - // We pass a test network interface to LinkAddressRequest with the same NIC - // ID and link endpoint used by the NIC we created earlier so that we can - // mock a link address request and observe the packets sent to the link - // endpoint even though the stack uses the real NIC. - if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) - } + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) + } - if test.expectedErr != nil { - return - } + if test.expectedErr != nil { + return + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) - } - if pkt.Route.RemoteAddress != test.expectedRemoteAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) - } - if pkt.Route.LocalAddress != lladdr1 { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) - } - checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), - checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedRemoteAddr), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), - )) + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = test.expectedRemoteLinkAddr + if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) + } + checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), + checker.SrcAddr(lladdr1), + checker.DstAddr(test.expectedRemoteAddr), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), + )) + }) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index ae4a8f508..94043ed4e 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -20,6 +20,7 @@ import ( "fmt" "hash/fnv" "math" + "reflect" "sort" "sync/atomic" "time" @@ -177,6 +178,7 @@ type endpoint struct { dispatcher stack.TransportDispatcher protocol *protocol stack *stack.Stack + stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. @@ -305,17 +307,17 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { // dupTentativeAddrDetected removes the tentative address if it exists. If the // address was generated via SLAAC, an attempt is made to generate a new // address. -func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } if addressEndpoint.GetKind() != stack.PermanentTentative { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an @@ -367,14 +369,14 @@ func (e *endpoint) transitionForwarding(forwarding bool) { } // Enable implements stack.NetworkEndpoint. -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If the NIC is not enabled, the endpoint can't do anything meaningful so // don't enable the endpoint. if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } // If the endpoint is already enabled, there is nothing for it to do. @@ -416,7 +418,7 @@ func (e *endpoint) Enable() *tcpip.Error { // // Addresses may have aleady completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. - var err *tcpip.Error + var err tcpip.Error e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { addr := addressEndpoint.AddressWithPrefix().Address if !header.IsV6UnicastAddress(addr) { @@ -497,7 +499,9 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } @@ -553,11 +557,11 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error { +func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) tcpip.Error { extHdrsLen := extensionHeaders.Length() length := pkt.Size() + extensionHeaders.Length() if length > math.MaxUint16 { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ @@ -583,7 +587,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta // fragments left to be processed. The IP header must already be present in the // original packet. The transport header protocol number is required to avoid // parsing the IPv6 extension headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { networkHeader := header.IPv6(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are @@ -596,13 +600,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // of 8 as per RFC 8200 section 4.5: // Each complete fragment, except possibly the last ("rightmost") one, is // an integer multiple of 8 octets long. - return 0, 1, tcpip.ErrMessageTooLong + return 0, 1, &tcpip.ErrMessageTooLong{} } if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { // As per RFC 8200 Section 4.5, the Transport Header is expected to be small // enough to fit in the first fragment. - return 0, 1, tcpip.ErrMessageTooLong + return 0, 1, &tcpip.ErrMessageTooLong{} } pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) @@ -622,17 +626,17 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { + if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { return err } // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. - e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() + e.stats.ip.IPTablesOutputDropped.Increment() return nil } @@ -660,7 +664,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { @@ -675,36 +679,37 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) }) - r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + stats.PacketsSent.IncrementBy(uint64(sent)) + stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } - r.Stats().IP.PacketsSent.Increment() + stats.PacketsSent.Increment() return nil } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } @@ -712,28 +717,29 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + stats := e.stats.ip linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { + if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { return 0, err } networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } if packetMustBeFragmented(pb, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt return nil }); err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } // Remove the packet that was just fragmented and process the rest. @@ -743,19 +749,19 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. - nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } return n, err } - r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. @@ -779,8 +785,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) + stats.PacketsSent.IncrementBy(uint64(n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -788,17 +794,17 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe n++ } - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) // Dropped packets aren't errors, so include them in the return value. return n + len(dropped), nil } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required checks. h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } ip := header.IPv6(h) @@ -823,14 +829,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // sending the packet. proto, _, _, _, ok := parse.IPv6(pkt) if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv6(pkt.NetworkHeader().View()) hopLimit := h.HopLimit() if hopLimit <= 1 { @@ -852,7 +858,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { networkEndpoint.(*endpoint).handlePacket(pkt) return nil } - if err != tcpip.ErrBadAddress { + if _, ok := err.(*tcpip.ErrBadAddress); !ok { return err } @@ -882,19 +888,21 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - stats.IP.PacketsReceived.Increment() + stats := e.stats.ip + + stats.PacketsReceived.Increment() if !e.isEnabled() { - stats.IP.DisabledPacketsReceived.Increment() + stats.DisabledPacketsReceived.Increment() return } // Loopback traffic skips the prerouting chain. if !e.nic.IsLoopback() { - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesPreroutingDropped.Increment() + stats.IPTablesPreroutingDropped.Increment() return } } @@ -906,11 +914,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() - stats := e.protocol.stack.Stats() + stats := e.stats.ip h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } srcAddr := h.SourceAddress() @@ -920,7 +928,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Multicast addresses must not be used as source addresses in IPv6 // packets or appear in any Routing header. if header.IsV6MulticastAddress(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.InvalidSourceAddressesReceived.Increment() return } @@ -930,7 +938,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { addressEndpoint.DecRef() } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() + stats.InvalidDestinationAddressesReceived.Increment() return } @@ -950,9 +958,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesInputDropped.Increment() + stats.IPTablesInputDropped.Increment() return } @@ -962,7 +971,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -986,7 +995,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -1075,8 +1084,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { it, done, err := it.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } if done { @@ -1103,8 +1112,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { switch lastHdr.(type) { case header.IPv6RawPayloadHeader: default: - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } } @@ -1112,8 +1121,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } @@ -1126,8 +1135,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // of the fragment, pointing to the Payload Length field of the // fragment packet. if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, @@ -1147,8 +1156,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // the fragment, pointing to the Fragment Offset field of the fragment // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, @@ -1173,8 +1182,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt, ) if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } @@ -1194,7 +1203,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -1244,12 +1253,12 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf - stats.IP.PacketsDelivered.Increment() + stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p e.handleICMP(pkt, hasFragmentHeader) } else { - stats.IP.PacketsDelivered.Increment() + stats.PacketsDelivered.Increment() switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: @@ -1314,7 +1323,7 @@ func (e *endpoint) Close() { e.mu.addressableEndpointState.Cleanup() e.mu.Unlock() - e.protocol.forgetEndpoint(e) + e.protocol.forgetEndpoint(e.nic.ID()) } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -1323,7 +1332,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. e.mu.Lock() @@ -1338,7 +1347,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // solicited-node multicast group and start duplicate address detection. // // Precondition: e.mu must be write locked. -func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) if err != nil { return nil, err @@ -1367,13 +1376,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } // RemovePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } return e.removePermanentEndpointLocked(addressEndpoint, true) @@ -1383,7 +1392,7 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { // it works with a stack.AddressEndpoint. // // Precondition: e.mu must be write locked. -func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error { +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) tcpip.Error { addr := addressEndpoint.AddressWithPrefix() unicast := header.IsV6UnicastAddress(addr.Address) if unicast { @@ -1408,12 +1417,12 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) + err := e.leaveGroupLocked(snmc) // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { - return err + if _, ok := err.(*tcpip.ErrBadLocalAddress); ok { + err = nil } - - return nil + return err } // hasPermanentAddressLocked returns true if the endpoint has a permanent @@ -1623,7 +1632,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) @@ -1632,9 +1641,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { // joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } e.mu.mld.joinGroup(addr) @@ -1642,7 +1651,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) @@ -1651,7 +1660,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error { return e.mu.mld.leaveGroup(addr) } @@ -1662,6 +1671,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { return e.mu.mld.isInGroup(addr) } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1673,7 +1687,9 @@ type protocol struct { mu struct { sync.RWMutex - eps map[*endpoint]struct{} + // eps is keyed by NICID to allow protocol methods to retrieve an endpoint + // when handling a packet, by looking at which NIC handled the packet. + eps map[tcpip.NICID]*endpoint } ids []uint32 @@ -1730,37 +1746,42 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L e.mu.mld.init(e) e.mu.Unlock() + stackStats := p.stack.Stats() + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP) + e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V6) + p.mu.Lock() defer p.mu.Unlock() - p.mu.eps[e] = struct{}{} + p.mu.eps[nic.ID()] = e return e } -func (p *protocol) forgetEndpoint(e *endpoint) { +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() - delete(p.mu.eps, e) + delete(p.mu.eps, nicID) } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: p.SetDefaultTTL(uint8(*v)) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -1814,7 +1835,7 @@ func (p *protocol) SetForwarding(v bool) { return } - for ep := range p.mu.eps { + for _, ep := range p.mu.eps { ep.transitionForwarding(v) } } @@ -1823,9 +1844,9 @@ func (p *protocol) SetForwarding(v bool) { // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, // which includes the length of the extension headers. -func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error) { if linkMTU < header.IPv6MinimumMTU { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } // As per RFC 7112 section 5, we should discard packets if their IPv6 header @@ -1836,7 +1857,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Erro // bytes ensures that the header chain length does not exceed the IPv6 // minimum MTU. if networkHeadersLen > header.IPv6MinimumMTU { - return 0, tcpip.ErrMalformedHeader + return 0, &tcpip.ErrMalformedHeader{} } networkMTU := linkMTU - uint32(networkHeadersLen) @@ -1906,7 +1927,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { hashIV: hashIV, } p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) - p.mu.eps = make(map[*endpoint]struct{}) + p.mu.eps = make(map[tcpip.NICID]*endpoint) p.SetDefaultTTL(DefaultTTL) return p } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index b65c9d060..8248052a3 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -21,6 +21,7 @@ import ( "io/ioutil" "math" "net" + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -370,12 +371,10 @@ func TestAddIpv6Address(t *testing.T) { t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) } - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != test.addr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok { + t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber) + } else if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr) } }) } @@ -997,8 +996,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { } // Should not have any more UDP packets. - if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -1989,8 +1989,9 @@ func TestReceiveIPv6Fragments(t *testing.T) { } } - if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -2473,11 +2474,11 @@ func TestWriteStats(t *testing.T) { writers := []struct { name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) }{ { name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { @@ -2489,7 +2490,7 @@ func TestWriteStats(t *testing.T) { }, }, { name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) }, }, @@ -2499,7 +2500,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -2574,7 +2575,7 @@ func (*limitedMatcher) Name() string { } // Match implements Matcher.Match. -func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) { if lm.limit == 0 { return true, false } @@ -2582,30 +2583,44 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, return false, false } +func knownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + func TestClearEndpointFromProtocolOnClose(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint) - { - proto.mu.Lock() - _, hasEP := proto.mu.eps[ep] - proto.mu.Unlock() - if !hasEP { - t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep) - } + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + var nicIDs []tcpip.NICID + + proto.mu.Lock() + foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if !hasEndpointBeforeClose { + t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) } ep.Close() - { - proto.mu.Lock() - _, hasEP := proto.mu.eps[ep] - proto.mu.Unlock() - if hasEP { - t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep) - } + proto.mu.Lock() + _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] + nicIDs = knownNICIDs(proto) + proto.mu.Unlock() + if hasEndpointAfterClose { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) } } @@ -2819,8 +2834,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize int allowPackets int outgoingErrors int - mockError *tcpip.Error - wantError *tcpip.Error + mockError tcpip.Error + wantError tcpip.Error }{ { description: "No frag", @@ -2829,8 +2844,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 0, outgoingErrors: 1, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag", @@ -2839,8 +2854,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 0, outgoingErrors: 3, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on second frag", @@ -2849,8 +2864,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 1, outgoingErrors: 2, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error when MTU is smaller than transport header", @@ -2860,7 +2875,7 @@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrMessageTooLong, + wantError: &tcpip.ErrMessageTooLong{}, }, { description: "Error when MTU is smaller than IPv6 minimum MTU", @@ -2870,7 +2885,7 @@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrInvalidEndpointState, + wantError: &tcpip.ErrInvalidEndpointState{}, }, } @@ -2884,8 +2899,8 @@ func TestFragmentationErrors(t *testing.T) { TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.wantError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + if diff := cmp.Diff(ft.wantError, err); diff != "" { + t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff) } if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) @@ -3053,3 +3068,22 @@ func TestForwarding(t *testing.T) { }) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // supposed to be bound. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index ec54d88cc..2cc0dfebd 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -68,14 +68,14 @@ func (mld *mldState) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: mld.ep.mu must be read locked. -func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) } // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: mld.ep.mu must be read locked. -func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error { _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) return err } @@ -112,7 +112,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { // joinGroup handles joining a new group and sending and scheduling the required // messages. // -// If the group is already joined, returns tcpip.ErrDuplicateAddress. +// If the group is already joined, returns *tcpip.ErrDuplicateAddress. // // Precondition: mld.ep.mu must be locked. func (mld *mldState) joinGroup(groupAddress tcpip.Address) { @@ -131,13 +131,13 @@ func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { // required. // // Precondition: mld.ep.mu must be locked. -func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) tcpip.Error { // LeaveGroup returns false only if the group was not joined. if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } // softLeaveAll leaves all groups from the perspective of MLD, but remains @@ -166,14 +166,14 @@ func (mld *mldState) sendQueuedReports() { // writePacket assembles and sends an MLD packet. // // Precondition: mld.ep.mu must be read locked. -func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { - sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - var mldStat *tcpip.StatCounter +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, tcpip.Error) { + sentStats := mld.ep.stats.icmp.packetsSent + var mldStat tcpip.MultiCounterStat switch mldType { case header.ICMPv6MulticastListenerReport: - mldStat = sentStats.MulticastListenerReport + mldStat = sentStats.multicastListenerReport case header.ICMPv6MulticastListenerDone: - mldStat = sentStats.MulticastListenerDone + mldStat = sentStats.multicastListenerDone default: panic(fmt.Sprintf("unrecognized mld type = %d", mldType)) } @@ -249,14 +249,14 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp Data: buffer.View(icmp).ToVectorisedView(), }) - if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.MLDHopLimit, }, extensionHeaders); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sentStats.Dropped.Increment() + sentStats.dropped.Increment() return false, err } mldStat.Increment() diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 1d8fee50b..d7dde1767 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -241,7 +241,7 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) // OnDefaultRouterDiscovered is called when a new default router is // discovered. Implementations must return true if the newly discovered @@ -614,10 +614,10 @@ type slaacPrefixState struct { // tentative. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error { // addr must be a valid unicast IPv6 address. if !header.IsV6UnicastAddress(addr) { - return tcpip.ErrAddressFamilyNotSupported + return &tcpip.ErrAddressFamilyNotSupported{} } if addressEndpoint.GetKind() != stack.PermanentTentative { @@ -666,7 +666,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE dadDone := remaining == 0 - var err *tcpip.Error + var err tcpip.Error if !dadDone { err = ndp.sendDADPacket(addr, addressEndpoint) } @@ -717,7 +717,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // addr. // // addr must be a tentative IPv6 address on ndp's IPv6 endpoint. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error { snmc := header.SolicitedNodeAddr(addr) icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) @@ -731,8 +731,8 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add Data: buffer.View(icmp).ToVectorisedView(), }) - sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + sent := ndp.ep.stats.icmp.packetsSent + if err := addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { @@ -740,10 +740,11 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add } if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } - sent.NeighborSolicit.Increment() + sent.neighborSolicit.Increment() + return nil } @@ -1855,20 +1856,20 @@ func (ndp *ndpState) startSolicitingRouters() { Data: buffer.View(icmpData).ToVectorisedView(), }) - sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + sent := ndp.ep.stats.icmp.packetsSent + if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. remaining = 0 } else { - sent.RouterSolicit.Increment() + sent.routerSolicit.Increment() remaining-- } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index b1a5a5510..1d38b8b05 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -90,7 +90,7 @@ type testNDPDispatcher struct { addr tcpip.Address } -func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { } func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { @@ -162,10 +162,15 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } -// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a +type linkResolutionResult struct { + linkAddr tcpip.LinkAddress + ok bool +} + +// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. -func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { +func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -231,45 +236,40 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } + ch := make(chan stack.LinkResolutionResult, 1) + err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { + ch <- r + }) - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + wantInvalid := uint64(0) + wantSucccess := true + if len(test.expectedLinkAddr) == 0 { + wantInvalid = 1 + wantSucccess = false + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) } } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + if err != nil { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) } + } - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { + t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + } + if got := invalid.Value(); got != wantInvalid { + t.Errorf("got invalid = %d, want = %d", got, wantInvalid) } }) } } -// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests +// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests // that receiving a valid NDP NS message with the Source Link Layer Address // option results in a new entry in the link address cache for the sender of // the message. -func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { +func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { const nicID = 1 tests := []struct { @@ -382,7 +382,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi } } -func TestNeighorSolicitationResponse(t *testing.T) { +func TestNeighborSolicitationResponse(t *testing.T) { const nicID = 1 nicAddr := lladdr0 remoteAddr := lladdr1 @@ -640,18 +640,12 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatal("expected an NDP NS response") } - if p.Route.LocalAddress != nicAddr { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } respNSDst := header.SolicitedNodeAddr(test.nsSrc) - if p.Route.RemoteAddress != respNSDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) - } - if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) + if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -727,10 +721,10 @@ func TestNeighorSolicitationResponse(t *testing.T) { } } -// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a +// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a // valid NDP NA message with the Target Link Layer Address option results in a // new entry in the link address cache for the target of the message. -func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { +func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -803,45 +797,40 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } + ch := make(chan stack.LinkResolutionResult, 1) + err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { + ch <- r + }) - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + wantInvalid := uint64(0) + wantSucccess := true + if len(test.expectedLinkAddr) == 0 { + wantInvalid = 1 + wantSucccess = false + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) } } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + if err != nil { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) } + } - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { + t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + } + if got := invalid.Value(); got != wantInvalid { + t.Errorf("got invalid = %d, want = %d", got, wantInvalid) } }) } } -// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests +// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests // that receiving a valid NDP NA message with the Target Link Layer Address // option does not result in a new entry in the neighbor cache for the target // of the message. -func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { +func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { const nicID = 1 tests := []struct { @@ -1183,6 +1172,118 @@ func TestNDPValidation(t *testing.T) { } +// TestNeighborAdvertisementValidation tests that the NIC validates received +// Neighbor Advertisements. +// +// In particular, if the IP Destination Address is a multicast address, and the +// Solicited flag is not zero, the Neighbor Advertisement is invalid and should +// be discarded. +func TestNeighborAdvertisementValidation(t *testing.T) { + tests := []struct { + name string + ipDstAddr tcpip.Address + solicitedFlag bool + valid bool + }{ + { + name: "Multicast IP destination address with Solicited flag set", + ipDstAddr: header.IPv6AllNodesMulticastAddress, + solicitedFlag: true, + valid: false, + }, + { + name: "Multicast IP destination address with Solicited flag unset", + ipDstAddr: header.IPv6AllNodesMulticastAddress, + solicitedFlag: false, + valid: true, + }, + { + name: "Unicast IP destination address with Solicited flag set", + ipDstAddr: lladdr0, + solicitedFlag: true, + valid: true, + }, + { + name: "Unicast IP destination address with Solicited flag unset", + ipDstAddr: lladdr0, + solicitedFlag: false, + valid: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.MessageBody()) + na.SetTargetAddress(lladdr1) + na.SetSolicitedFlag(test.solicitedFlag) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: test.ipDstAddr, + }) + + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + rxNA := stats.NeighborAdvert + + if got := rxNA.Value(); got != 0 { + t.Fatalf("got rxNA = %d, want = 0", got) + } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + + if got := rxNA.Value(); got != 1 { + t.Fatalf("got rxNA = %d, want = 1", got) + } + var wantInvalid uint64 = 1 + if test.valid { + wantInvalid = 0 + } + if got := invalid.Value(); got != wantInvalid { + t.Fatalf("got invalid = %d, want = %d", got, wantInvalid) + } + // As per RFC 4861 section 7.2.5: + // When a valid Neighbor Advertisement is received ... + // If no entry exists, the advertisement SHOULD be silently discarded. + // There is no need to create an entry if none exists, since the + // recipient has apparently not initiated any communication with the + // target. + if neighbors, err := s.Neighbors(nicID); err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } else if len(neighbors) != 0 { + t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + }) + } +} + // TestRouterAdvertValidation tests that when the NIC is configured to handle // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go new file mode 100644 index 000000000..0839be3cd --- /dev/null +++ b/pkg/tcpip/network/ipv6/stats.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.IPNetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to the IPv6 protocol family. +type Stats struct { + // IP holds IPv6 statistics. + IP tcpip.IPStats + + // ICMP holds ICMPv6 statistics. + ICMP tcpip.ICMPv6Stats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*Stats) IsNetworkEndpointStats() {} + +// IPStats implements stack.IPNetworkEndointStats +func (s *Stats) IPStats() *tcpip.IPStats { + return &s.IP +} + +type sharedStats struct { + localStats Stats + ip ip.MultiCounterIPStats + icmp multiCounterICMPv6Stats +} + +// LINT.IfChange(multiCounterICMPv6PacketStats) + +type multiCounterICMPv6PacketStats struct { + echoRequest tcpip.MultiCounterStat + echoReply tcpip.MultiCounterStat + dstUnreachable tcpip.MultiCounterStat + packetTooBig tcpip.MultiCounterStat + timeExceeded tcpip.MultiCounterStat + paramProblem tcpip.MultiCounterStat + routerSolicit tcpip.MultiCounterStat + routerAdvert tcpip.MultiCounterStat + neighborSolicit tcpip.MultiCounterStat + neighborAdvert tcpip.MultiCounterStat + redirectMsg tcpip.MultiCounterStat + multicastListenerQuery tcpip.MultiCounterStat + multicastListenerReport tcpip.MultiCounterStat + multicastListenerDone tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6PacketStats) init(a, b *tcpip.ICMPv6PacketStats) { + m.echoRequest.Init(a.EchoRequest, b.EchoRequest) + m.echoReply.Init(a.EchoReply, b.EchoReply) + m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable) + m.packetTooBig.Init(a.PacketTooBig, b.PacketTooBig) + m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded) + m.paramProblem.Init(a.ParamProblem, b.ParamProblem) + m.routerSolicit.Init(a.RouterSolicit, b.RouterSolicit) + m.routerAdvert.Init(a.RouterAdvert, b.RouterAdvert) + m.neighborSolicit.Init(a.NeighborSolicit, b.NeighborSolicit) + m.neighborAdvert.Init(a.NeighborAdvert, b.NeighborAdvert) + m.redirectMsg.Init(a.RedirectMsg, b.RedirectMsg) + m.multicastListenerQuery.Init(a.MulticastListenerQuery, b.MulticastListenerQuery) + m.multicastListenerReport.Init(a.MulticastListenerReport, b.MulticastListenerReport) + m.multicastListenerDone.Init(a.MulticastListenerDone, b.MulticastListenerDone) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6PacketStats) + +// LINT.IfChange(multiCounterICMPv6SentPacketStats) + +type multiCounterICMPv6SentPacketStats struct { + multiCounterICMPv6PacketStats + dropped tcpip.MultiCounterStat + rateLimited tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6SentPacketStats) init(a, b *tcpip.ICMPv6SentPacketStats) { + m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats) + m.dropped.Init(a.Dropped, b.Dropped) + m.rateLimited.Init(a.RateLimited, b.RateLimited) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6SentPacketStats) + +// LINT.IfChange(multiCounterICMPv6ReceivedPacketStats) + +type multiCounterICMPv6ReceivedPacketStats struct { + multiCounterICMPv6PacketStats + unrecognized tcpip.MultiCounterStat + invalid tcpip.MultiCounterStat + routerOnlyPacketsDroppedByHost tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6ReceivedPacketStats) init(a, b *tcpip.ICMPv6ReceivedPacketStats) { + m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats) + m.unrecognized.Init(a.Unrecognized, b.Unrecognized) + m.invalid.Init(a.Invalid, b.Invalid) + m.routerOnlyPacketsDroppedByHost.Init(a.RouterOnlyPacketsDroppedByHost, b.RouterOnlyPacketsDroppedByHost) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6ReceivedPacketStats) + +// LINT.IfChange(multiCounterICMPv6Stats) + +type multiCounterICMPv6Stats struct { + packetsSent multiCounterICMPv6SentPacketStats + packetsReceived multiCounterICMPv6ReceivedPacketStats +} + +func (m *multiCounterICMPv6Stats) init(a, b *tcpip.ICMPv6Stats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6Stats) diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index d0ffc299a..bd62c4482 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -6,8 +6,10 @@ go_library( name = "testutil", srcs = [ "testutil.go", + "testutil_unsafe.go", ], visibility = [ + "//pkg/tcpip/network/arp:__pkg__", "//pkg/tcpip/network/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 3af44991f..f5fa77b65 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -19,6 +19,8 @@ package testutil import ( "fmt" "math/rand" + "reflect" + "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -33,7 +35,7 @@ type MockLinkEndpoint struct { WrittenPackets []*stack.PacketBuffer mtu uint32 - err *tcpip.Error + err tcpip.Error allowPackets int } @@ -41,7 +43,7 @@ type MockLinkEndpoint struct { // // err is the error that will be returned once allowPackets packets are written // to the endpoint. -func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint { +func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { return &MockLinkEndpoint{ mtu: mtu, err: err, @@ -62,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } // WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if ep.allowPackets == 0 { return ep.err } @@ -72,7 +74,7 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip } // WritePackets implements LinkEndpoint.WritePackets. -func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { var n int for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { @@ -127,3 +129,69 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi } return pkt } + +func checkFieldCounts(ref, multi reflect.Value) error { + refTypeName := ref.Type().Name() + multiTypeName := multi.Type().Name() + refNumField := ref.NumField() + multiNumField := multi.NumField() + + if refNumField != multiNumField { + return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) + } + + return nil +} + +func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { + s, ok := ref.Addr().Interface().(**tcpip.StatCounter) + if !ok { + return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) + } + + // The field names are expected to match (case insensitive). + if !strings.EqualFold(refName, multiName) { + return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) + } + + base := (*s).Value() + m.Increment() + if (*s).Value() != base+1 { + return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) + } + + return nil +} + +// ValidateMultiCounterStats verifies that every counter stored in multi is +// correctly tracking its counterpart in the given counters. +func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { + for _, c := range counters { + if err := checkFieldCounts(c, multi); err != nil { + return err + } + } + + for i := 0; i < multi.NumField(); i++ { + multiName := multi.Type().Field(i).Name + multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) + + if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { + for _, c := range counters { + if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { + return err + } + } + } else { + var countersNextField []reflect.Value + for _, c := range counters { + countersNextField = append(countersNextField, c.Field(i)) + } + if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/tcpip/network/testutil/testutil_unsafe.go b/pkg/tcpip/network/testutil/testutil_unsafe.go new file mode 100644 index 000000000..5ff764800 --- /dev/null +++ b/pkg/tcpip/network/testutil/testutil_unsafe.go @@ -0,0 +1,26 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "reflect" + "unsafe" +) + +// unsafeExposeUnexportedFields takes a Value and returns a version of it in +// which even unexported fields can be read and written. +func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { + return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() +} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 2bad05a2e..57abec5c9 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -18,5 +18,6 @@ go_test( library = ":ports", deps = [ "//pkg/tcpip", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index d87193650..11dbdbbcf 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -329,7 +329,7 @@ func NewPortManager() *PortManager { // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { offset := uint32(rand.Int31n(numEphemeralPorts)) return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) } @@ -348,7 +348,7 @@ func (s *PortManager) incPortHint() { // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. -func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) if err == nil { s.incPortHint() @@ -361,7 +361,7 @@ func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uin // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { for i := uint32(0); i < count; i++ { port = uint16(FirstEphemeral + (offset+i)%count) ok, err := testPort(port) @@ -374,7 +374,7 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } } - return 0, tcpip.ErrNoPortAvailable + return 0, &tcpip.ErrNoPortAvailable{} } // IsPortAvailable tests if the given port is available on all given protocols. @@ -404,7 +404,7 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // An optional testPort closure can be passed in which if provided will be used // to test if the picked port can be used. The function should return true if // the port is safe to use, false otherwise. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() @@ -414,17 +414,17 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // protocols. if port != 0 { if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { - return 0, tcpip.ErrPortInUse + return 0, &tcpip.ErrPortInUse{} } if testPort != nil && !testPort(port) { s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst) - return 0, tcpip.ErrPortInUse + return 0, &tcpip.ErrPortInUse{} } return port, nil } // A port wasn't specified, so try to find one. - return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) { return false, nil } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 4bc949fd8..e70fbb72b 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -32,7 +33,7 @@ const ( type portReserveTestAction struct { port uint16 ip tcpip.Address - want *tcpip.Error + want tcpip.Error flags Flags release bool device tcpip.NICID @@ -50,19 +51,19 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress, want: nil}, {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! */ - {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, + {port: 80, ip: anyIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 80, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, }, }, { tname: "bind to inaddr any", actions: []portReserveTestAction{ {port: 22, ip: anyIPAddress, want: nil}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, ip: fakeIPAddress, want: nil}, @@ -80,8 +81,8 @@ func TestPortReservation(t *testing.T) { {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, @@ -91,14 +92,14 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, @@ -107,7 +108,7 @@ func TestPortReservation(t *testing.T) { tname: "bind twice with device fails", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 3, want: nil}, - {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 3, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind to device", @@ -119,50 +120,50 @@ func TestPortReservation(t *testing.T) { tname: "bind to device and then without device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind without device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, want: nil}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "binding with reuseport and device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 999, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "mixing reuseport and not reuseport by binding to device", @@ -177,14 +178,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind and release", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, // Release the bind to device 0 and try again. @@ -195,7 +196,7 @@ func TestPortReservation(t *testing.T) { tname: "bind twice with reuseport once", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "release an unreserved device", @@ -213,16 +214,16 @@ func TestPortReservation(t *testing.T) { tname: "bind with reuseaddr", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, }, }, { tname: "bind twice with reuseaddr once", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseaddr and reuseport", @@ -236,14 +237,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseaddr and reuseport, and then reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", @@ -264,14 +265,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseport, and then reuseaddr and reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", @@ -283,7 +284,7 @@ func TestPortReservation(t *testing.T) { tname: "bind tuple with reuseaddr, and then wildcard", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", @@ -295,7 +296,7 @@ func TestPortReservation(t *testing.T) { tname: "bind tuple with reuseaddr, and then wildcard", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind two tuples with reuseaddr", @@ -313,7 +314,7 @@ func TestPortReservation(t *testing.T) { tname: "bind wildcard, and then tuple with reuseaddr", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind wildcard twice with reuseaddr", @@ -333,8 +334,8 @@ func TestPortReservation(t *testing.T) { continue } gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) - if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) + if diff := cmp.Diff(test.want, err); diff != "" { + t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) @@ -345,30 +346,29 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { - customErr := &tcpip.Error{} for _, test := range []struct { name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error + f func(port uint16) (bool, tcpip.Error) + wantErr tcpip.Error wantPort uint16 }{ { name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, { name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr + f: func(port uint16) (bool, tcpip.Error) { + return false, &tcpip.ErrBadBuffer{} }, - wantErr: customErr, + wantErr: &tcpip.ErrBadBuffer{}, }, { name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port == FirstEphemeral+42 { return true, nil } @@ -378,49 +378,52 @@ func TestPickEphemeralPort(t *testing.T) { }, { name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port < FirstEphemeral { return true, nil } return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() - if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + port, err := pm.PickEphemeralPort(test.f) + if diff := cmp.Diff(test.wantErr, err); diff != "" { + t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) + } + if port != test.wantPort { + t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort) } }) } } func TestPickEphemeralPortStable(t *testing.T) { - customErr := &tcpip.Error{} for _, test := range []struct { name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error + f func(port uint16) (bool, tcpip.Error) + wantErr tcpip.Error wantPort uint16 }{ { name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, { name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr + f: func(port uint16) (bool, tcpip.Error) { + return false, &tcpip.ErrBadBuffer{} }, - wantErr: customErr, + wantErr: &tcpip.ErrBadBuffer{}, }, { name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port == FirstEphemeral+42 { return true, nil } @@ -430,20 +433,24 @@ func TestPickEphemeralPortStable(t *testing.T) { }, { name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port < FirstEphemeral { return true, nil } return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) - if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + port, err := pm.PickEphemeralPortStable(portOffset, test.f) + if diff := cmp.Diff(test.wantErr, err); diff != "" { + t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) + } + if port != test.wantPort { + t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort) } }) } diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index cf0a5fefe..db9b91815 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -8,7 +8,6 @@ go_binary( visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/rawfile", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 3b4f900e3..856ea998d 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -41,7 +41,7 @@ package main import ( - "bufio" + "bytes" "fmt" "log" "math/rand" @@ -51,7 +51,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" @@ -71,24 +70,21 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) { close(ch) }() - r := bufio.NewReader(os.Stdin) - for { - v := buffer.NewView(1024) - n, err := r.Read(v) - if err != nil { - return - } - - v.CapLength(n) - for len(v) > 0 { - n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - if err != nil { - fmt.Println("Write failed:", err) - return + var b bytes.Buffer + if err := func() error { + for { + if _, err := b.ReadFrom(os.Stdin); err != nil { + return fmt.Errorf("b.ReadFrom failed: %w", err) } - v.TrimFront(int(n)) + for b.Len() != 0 { + if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil { + return fmt.Errorf("ep.Write failed: %s", err) + } + } } + }(); err != nil { + fmt.Println(err) } } @@ -179,7 +175,7 @@ func main() { waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventOut) terr := ep.Connect(remote) - if terr == tcpip.ErrConnectStarted { + if _, ok := terr.(*tcpip.ErrConnectStarted); ok { fmt.Println("Connect is pending...") <-notifyCh terr = ep.LastError() @@ -202,11 +198,11 @@ func main() { for { _, err := ep.Read(os.Stdout, tcpip.ReadOptions{}) if err != nil { - if err == tcpip.ErrClosedForReceive { + if _, ok := err.(*tcpip.ErrClosedForReceive); ok { break } - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 3ac562756..9b23df3a9 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -20,6 +20,7 @@ package main import ( + "bytes" "flag" "io" "log" @@ -50,7 +51,7 @@ type endpointWriter struct { } type tcpipError struct { - inner *tcpip.Error + inner tcpip.Error } func (e *tcpipError) Error() string { @@ -58,7 +59,9 @@ func (e *tcpipError) Error() string { } func (e *endpointWriter) Write(p []byte) (int, error) { - n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(p) + n, err := e.ep.Write(&r, tcpip.WriteOptions{}) if err != nil { return int(n), &tcpipError{ inner: err, @@ -86,7 +89,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { for { _, err := ep.Read(&w, tcpip.ReadOptions{}) if err != nil { - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } @@ -214,7 +217,7 @@ func main() { for { n, wq, err := ep.Accept(nil) if err != nil { - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index f3ad40fdf..019d6a63c 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -15,11 +15,16 @@ package tcpip import ( + "math" "sync/atomic" "gvisor.dev/gvisor/pkg/sync" ) +// PacketOverheadFactor is used to multiply the value provided by the user on a +// SetSockOpt for setting the send/receive buffer sizes sockets. +const PacketOverheadFactor = 2 + // SocketOptionsHandler holds methods that help define endpoint specific // behavior for socket level socket options. These must be implemented by // endpoints to get notified when socket level options are set. @@ -41,13 +46,19 @@ type SocketOptionsHandler interface { OnCorkOptionSet(v bool) // LastError is invoked when SO_ERROR is read for an endpoint. - LastError() *Error + LastError() Error // UpdateLastError updates the endpoint specific last error field. - UpdateLastError(err *Error) + UpdateLastError(err Error) // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. HasNIC(v int32) bool + + // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE. + GetSendBufferSize() (int64, Error) + + // IsUnixSocket is invoked to check if the socket is of unix domain. + IsUnixSocket() bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -72,18 +83,39 @@ func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {} func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {} // LastError implements SocketOptionsHandler.LastError. -func (*DefaultSocketOptionsHandler) LastError() *Error { +func (*DefaultSocketOptionsHandler) LastError() Error { return nil } // UpdateLastError implements SocketOptionsHandler.UpdateLastError. -func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} +func (*DefaultSocketOptionsHandler) UpdateLastError(Error) {} // HasNIC implements SocketOptionsHandler.HasNIC. func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { return false } +// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize. +func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) { + return 0, nil +} + +// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket. +func (*DefaultSocketOptionsHandler) IsUnixSocket() bool { + return false +} + +// StackHandler holds methods to access the stack options. These must be +// implemented by the stack. +type StackHandler interface { + // Option allows retrieving stack wide options. + Option(option interface{}) Error + + // TransportProtocolOption allows retrieving individual protocol level + // option values. + TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) Error +} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. // @@ -91,6 +123,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { type SocketOptions struct { handler SocketOptionsHandler + // StackHandler is initialized at the creation time and will not change. + stackHandler StackHandler `state:"manual"` + // These fields are accessed and modified using atomic operations. // broadcastEnabled determines whether datagram sockets are allowed to @@ -170,6 +205,14 @@ type SocketOptions struct { // bindToDevice determines the device to which the socket is bound. bindToDevice int32 + // getSendBufferLimits provides the handler to get the min, default and + // max size for send buffer. It is initialized at the creation time and + // will not change. + getSendBufferLimits GetSendBufferLimits `state:"manual"` + + // sendBufferSize determines the send buffer size for this socket. + sendBufferSize int64 + // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -180,8 +223,10 @@ type SocketOptions struct { // InitHandler initializes the handler. This must be called before using the // socket options utility. -func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) { +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) { so.handler = handler + so.stackHandler = stack + so.getSendBufferLimits = getSendBufferLimits } func storeAtomicBool(addr *uint32, v bool) { @@ -193,7 +238,7 @@ func storeAtomicBool(addr *uint32, v bool) { } // SetLastError sets the last error for a socket. -func (so *SocketOptions) SetLastError(err *Error) { +func (so *SocketOptions) SetLastError(err Error) { so.handler.UpdateLastError(err) } @@ -378,7 +423,7 @@ func (so *SocketOptions) SetRecvError(v bool) { } // GetLastError gets value for SO_ERROR option. -func (so *SocketOptions) GetLastError() *Error { +func (so *SocketOptions) GetLastError() Error { return so.handler.LastError() } @@ -435,7 +480,7 @@ type SockError struct { sockErrorEntry // Err is the error caused by the errant packet. - Err *Error + Err Error // ErrOrigin indicates the error origin. ErrOrigin SockErrOrigin // ErrType is the type in the ICMP header. @@ -493,7 +538,7 @@ func (so *SocketOptions) QueueErr(err *SockError) { } // QueueLocalErr queues a local error onto the local queue. -func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { +func (so *SocketOptions) QueueLocalErr(err Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { so.QueueErr(&SockError{ Err: err, ErrOrigin: SockExtErrorOriginLocal, @@ -510,11 +555,52 @@ func (so *SocketOptions) GetBindToDevice() int32 { } // SetBindToDevice sets value for SO_BINDTODEVICE option. -func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { +func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { if !so.handler.HasNIC(bindToDevice) { - return ErrUnknownDevice + return &ErrUnknownDevice{} } atomic.StoreInt32(&so.bindToDevice, bindToDevice) return nil } + +// GetSendBufferSize gets value for SO_SNDBUF option. +func (so *SocketOptions) GetSendBufferSize() (int64, Error) { + if so.handler.IsUnixSocket() { + return so.handler.GetSendBufferSize() + } + return atomic.LoadInt64(&so.sendBufferSize), nil +} + +// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the +// stack handler should be invoked to set the send buffer size. +func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { + if so.handler.IsUnixSocket() { + return + } + + v := sendBufferSize + if notify { + // TODO(b/176170271): Notify waiters after size has grown. + // Make sure the send buffer size is within the min and max + // allowed. + ss := so.getSendBufferLimits(so.stackHandler) + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the send buffer size with min and max values. + // Multiply it by factor of 2. + if v > max { + v = max + } + + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min + } + } else { + v = math.MaxInt32 + } + } + atomic.StoreInt64(&so.sendBufferSize, v) +} diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index cd423bf71..e5590ecc0 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,7 +117,7 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) @@ -143,10 +143,10 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr // AddAndAcquireTemporaryAddress adds a temporary address. // -// Returns tcpip.ErrDuplicateAddress if the address exists. +// Returns *tcpip.ErrDuplicateAddress if the address exists. // // The temporary address's endpoint is acquired and returned. -func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) @@ -176,11 +176,11 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // If the addressable endpoint already has the address in a non-permanent state, // and addAndAcquireAddressLocked is adding a permanent address, that address is // promoted in place and its properties set to the properties provided. If the -// address already exists in any other state, then tcpip.ErrDuplicateAddress is +// address already exists in any other state, then *tcpip.ErrDuplicateAddress is // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -190,7 +190,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We are adding a non-permanent address but the address exists. No need // to go any further since we can only promote existing temporary/expired // addresses to permanent. - return nil, tcpip.ErrDuplicateAddress + return nil, &tcpip.ErrDuplicateAddress{} } addrState.mu.Lock() @@ -198,7 +198,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState.mu.Unlock() // We are adding a permanent address but a permanent address already // exists. - return nil, tcpip.ErrDuplicateAddress + return nil, &tcpip.ErrDuplicateAddress{} } if addrState.mu.refs == 0 { @@ -293,7 +293,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // RemovePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { a.mu.Lock() defer a.mu.Unlock() return a.removePermanentAddressLocked(addr) @@ -303,10 +303,10 @@ func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *t // requirements. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { +func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) tcpip.Error { addrState, ok := a.mu.endpoints[addr] if !ok { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } return a.removePermanentEndpointLocked(addrState) @@ -314,10 +314,10 @@ func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Addre // RemovePermanentEndpoint removes the passed endpoint if it is associated with // a and permanent. -func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error { +func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) tcpip.Error { addrState, ok := ep.(*addressState) if !ok || addrState.addressableEndpointState != a { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } a.mu.Lock() @@ -329,9 +329,9 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * // requirements. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error { +func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) tcpip.Error { if !addrState.GetKind().IsPermanent() { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } addrState.SetKind(PermanentExpired) @@ -574,9 +574,11 @@ func (a *AddressableEndpointState) Cleanup() { defer a.mu.Unlock() for _, ep := range a.mu.endpoints { - // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is + // removePermanentEndpointLocked returns *tcpip.ErrBadLocalAddress if ep is // not a permanent address. - if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := a.removePermanentEndpointLocked(ep); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err)) } } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5e649cca6..54617f2e6 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -198,15 +198,15 @@ type bucket struct { // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. -func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { +func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, tcpip.ErrUnknownProtocol + return tupleID{}, &tcpip.ErrUnknownProtocol{} } tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { - return tupleID{}, tcpip.ErrUnknownProtocol + return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ @@ -617,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -631,10 +631,10 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ conn, _ := ct.connForTID(tid) if conn == nil { // Not a tracked connection. - return "", 0, tcpip.ErrNotConnected + return "", 0, &tcpip.ErrNotConnected{} } else if conn.manip == manipNone { // Unmanipulated connection. - return "", 0, tcpip.ErrInvalidOptionValue + return "", 0, &tcpip.ErrInvalidOptionValue{} } return conn.original.dstAddr, conn.original.dstPort, nil diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 4908848e9..63a42a2ea 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -41,6 +41,8 @@ const ( protocolNumberOffset = 2 ) +var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) + // fwdTestNetworkEndpoint is a network-layer protocol endpoint. // Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only // use the first three: destination address, source address, and transport @@ -53,9 +55,7 @@ type fwdTestNetworkEndpoint struct { dispatcher TransportDispatcher } -var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) - -func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { +func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { return nil } @@ -104,7 +104,7 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { +func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { return 0 } @@ -112,7 +112,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) @@ -124,14 +124,14 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error { // The network header should not already be populated. if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) @@ -141,6 +141,21 @@ func (f *fwdTestNetworkEndpoint) Close() { f.AddressableEndpointState.Cleanup() } +// Stats implements stack.NetworkEndpoint. +func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats { + return &fwdTestNetworkEndpointStats{} +} + +var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil) + +type fwdTestNetworkEndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} + +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) +var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) + // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { @@ -158,18 +173,15 @@ type fwdTestNetworkProtocol struct { } } -var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) -var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) - -func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } -func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { +func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen } -func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { +func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { return fwdTestNetDefaultPrefixLen } @@ -195,19 +207,19 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddress return e } -func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } -func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) @@ -307,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -323,7 +335,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.WritePacket(r, gso, protocol, pkt) @@ -356,10 +368,6 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC UseNeighborCache: useNeighborCache, }) - if !useNeighborCache { - proto.addrCache = s.linkAddrCache - } - // Enable forwarding. s.SetForwarding(proto.Number(), true) @@ -389,13 +397,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC t.Fatal("AddAddress #2 failed:", err) } + nic, ok := s.nics[2] + if !ok { + t.Fatal("NIC 2 does not exist") + } if useNeighborCache { // Control the neighbor cache for NIC 2. - nic, ok := s.nics[2] - if !ok { - t.Fatal("failed to get the neighbor cache for NIC 2") - } proto.neigh = nic.neigh + } else { + proto.addrCache = nic.linkAddrCache } // Route all packets to NIC 2. @@ -481,7 +491,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -607,7 +617,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") } }, }, @@ -692,7 +702,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -768,7 +778,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -858,7 +868,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 09c7811fa..63832c200 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -229,7 +229,7 @@ func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { // ReplaceTable replaces or inserts table by name. It panics when an invalid id // is provided. -func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -267,11 +267,11 @@ const ( // dropped. // // TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from -// which address and nicName can be gathered. Currently, address is only -// needed for prerouting and nicName is only needed for output. +// which address can be gathered. Currently, address is only needed for +// prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { + if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(pkt, hook, nicName) { + if !rule.Filter.match(pkt, hook, inNicName, outNicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName) if hotdrop { return RuleDrop, 0 } @@ -483,11 +483,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { - return "", 0, tcpip.ErrNotConnected + return "", 0, &tcpip.ErrNotConnected{} } return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 56a3e7861..fd9d61e39 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -210,8 +210,19 @@ type IPHeaderFilter struct { // filter will match packets that fail the source comparison. SrcInvert bool - // OutputInterface matches the name of the outgoing interface for the - // packet. + // InputInterface matches the name of the incoming interface for the packet. + InputInterface string + + // InputInterfaceMask masks the characters of the interface name when + // comparing with InputInterface. + InputInterfaceMask string + + // InputInterfaceInvert inverts the meaning of incoming interface check, + // i.e. when true the filter will match packets that fail the incoming + // interface comparison. + InputInterfaceInvert bool + + // OutputInterface matches the name of the outgoing interface for the packet. OutputInterface string // OutputInterfaceMask masks the characters of the interface name when @@ -228,7 +239,7 @@ type IPHeaderFilter struct { // // Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 // or IPv6 header length. -func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( // TODO(gvisor.dev/issue/170): Support other filter fields. @@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return false } - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(fl.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which - // begins with the name should be matched. - ifName := fl.OutputInterface - matches := nicName == ifName - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } - return fl.OutputInterfaceInvert != matches + switch hook { + case Prerouting, Input: + return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) + case Output: + return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) + case Forward, Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + return true + default: + panic(fmt.Sprintf("unknown hook: %d", hook)) } +} - return true +func matchIfName(nicName string, ifName string, invert bool) bool { + n := len(ifName) + if n == 0 { + // If the interface name is omitted in the filter, any interface will match. + return true + } + // If the interface name ends with '+', any interface which begins with the + // name should be matched. + var matches bool + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return matches != invert } // NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header @@ -320,7 +340,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index b600a1cab..930b8f795 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -24,12 +24,16 @@ import ( const linkAddrCacheSize = 512 // max cache entries +var _ LinkAddressCache = (*linkAddrCache)(nil) + // linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. // // The entries are stored in a ring buffer, oldest entry replaced first. // // This struct is safe for concurrent use. type linkAddrCache struct { + nic *NIC + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -41,9 +45,9 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - cache struct { + mu struct { sync.Mutex - table map[tcpip.FullAddress]*linkAddrEntry + table map[tcpip.Address]*linkAddrEntry lru linkAddrEntryList } } @@ -77,31 +81,42 @@ type linkAddrEntry struct { // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry - // TODO(gvisor.dev/issue/5150): move these fields under mu. - // mu protects the fields below. - mu sync.RWMutex + cache *linkAddrCache - addr tcpip.FullAddress - linkAddr tcpip.LinkAddress - expiration time.Time - s entryState + mu struct { + sync.RWMutex - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} + addr tcpip.Address + linkAddr tcpip.LinkAddress + expiration time.Time + s entryState - // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. + done chan struct{} + + // onResolve is called with the result of address resolution. + onResolve []func(LinkResolutionResult) + } } func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { - for _, callback := range e.onResolve { - callback(linkAddr, len(linkAddr) != 0) + res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} + for _, callback := range e.mu.onResolve { + callback(res) } - e.onResolve = nil - if ch := e.done; ch != nil { + e.mu.onResolve = nil + if ch := e.mu.done; ch != nil { close(ch) - e.done = nil + e.mu.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0) } } @@ -114,30 +129,30 @@ func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { // // Precondition: e.mu must be locked func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { - if e.s == incomplete && ns == ready { - e.notifyCompletionLocked(e.linkAddr) + if e.mu.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.mu.linkAddr) } - if expiration.IsZero() || expiration.After(e.expiration) { - e.expiration = expiration + if expiration.IsZero() || expiration.After(e.mu.expiration) { + e.mu.expiration = expiration } - e.s = ns + e.mu.s = ns } // add adds a k -> v mapping to the cache. -func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { +func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is // relative to the time when information was learned, rather than when it // happened to be inserted into the cache. expiration := time.Now().Add(c.ageLimit) - c.cache.Lock() + c.mu.Lock() entry := c.getOrCreateEntryLocked(k) - c.cache.Unlock() - entry.mu.Lock() defer entry.mu.Unlock() - entry.linkAddr = v + c.mu.Unlock() + + entry.mu.linkAddr = v entry.changeStateLocked(ready, expiration) } @@ -150,19 +165,19 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { - if entry, ok := c.cache.table[k]; ok { - c.cache.lru.Remove(entry) - c.cache.lru.PushFront(entry) +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { + if entry, ok := c.mu.table[k]; ok { + c.mu.lru.Remove(entry) + c.mu.lru.PushFront(entry) return entry } var entry *linkAddrEntry - if len(c.cache.table) == linkAddrCacheSize { - entry = c.cache.lru.Back() + if len(c.mu.table) == linkAddrCacheSize { + entry = c.mu.lru.Back() entry.mu.Lock() - delete(c.cache.table, entry.addr) - c.cache.lru.Remove(entry) + delete(c.mu.table, entry.mu.addr) + c.mu.lru.Remove(entry) // Wake waiters and mark the soon-to-be-reused entry as expired. entry.notifyCompletionLocked("" /* linkAddr */) @@ -172,53 +187,56 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } *entry = linkAddrEntry{ - addr: k, - s: incomplete, + cache: c, } - c.cache.table[k] = entry - c.cache.lru.PushFront(entry) + entry.mu.Lock() + entry.mu.addr = k + entry.mu.s = incomplete + entry.mu.Unlock() + c.mu.table[k] = entry + c.mu.lru.PushFront(entry) return entry } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { - c.cache.Lock() - defer c.cache.Unlock() +func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + c.mu.Lock() entry := c.getOrCreateEntryLocked(k) entry.mu.Lock() defer entry.mu.Unlock() + c.mu.Unlock() - switch s := entry.s; s { + switch s := entry.mu.s; s { case ready: - if !time.Now().After(entry.expiration) { + if !time.Now().After(entry.mu.expiration) { // Not expired. if onResolve != nil { - onResolve(entry.linkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true}) } - return entry.linkAddr, nil, nil + return entry.mu.linkAddr, nil, nil } entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: if onResolve != nil { - entry.onResolve = append(entry.onResolve, onResolve) + entry.mu.onResolve = append(entry.mu.onResolve, onResolve) } - if entry.done == nil { - entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + if entry.mu.done == nil { + entry.mu.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) + linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): @@ -234,10 +252,10 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has // succeeded and mark the entry accordingly. Returns true if request can stop, // false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { - c.cache.Lock() - defer c.cache.Unlock() - entry, ok := c.cache.table[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool { + c.mu.Lock() + defer c.mu.Unlock() + entry, ok := c.mu.table[k] if !ok { // Entry was evicted from the cache. return true @@ -245,7 +263,7 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att entry.mu.Lock() defer entry.mu.Unlock() - switch s := entry.s; s { + switch s := entry.mu.s; s { case ready: // Entry was made ready by resolver. case incomplete: @@ -255,19 +273,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att } // Max number of retries reached, delete entry. entry.notifyCompletionLocked("" /* linkAddr */) - delete(c.cache.table, k) + delete(c.mu.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } return true } -func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { +func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { c := &linkAddrCache{ + nic: nic, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } - c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d7ac6cf5f..466a5e8d9 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -26,7 +26,7 @@ import ( ) type testaddr struct { - addr tcpip.FullAddress + addr tcpip.Address linkAddr tcpip.LinkAddress } @@ -35,7 +35,7 @@ var testAddrs = func() []testaddr { for i := 0; i < 4*linkAddrCacheSize; i++ { addr := fmt.Sprintf("Addr%06d", i) addrs = append(addrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + addr: tcpip.Address(addr), linkAddr: tcpip.LinkAddress("Link" + addr), }) } @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { @@ -59,8 +59,8 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { for _, ta := range testAddrs { - if ta.addr.Addr == addr { - r.cache.add(ta.addr, ta.linkAddr) + if ta.addr == addr { + r.cache.AddLinkAddress(ta.addr, ta.linkAddr) break } } @@ -77,13 +77,13 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { got, ch, err := c.get(addr, linkRes, "", nil, nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { - return got, tcpip.ErrTimeout + return got, &tcpip.ErrTimeout{} } attemptedResolution = true <-ch @@ -93,17 +93,23 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } +func newEmptyNIC() *NIC { + n := &NIC{} + n.linkResQueue.init(n) + return n +} + func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. @@ -111,25 +117,25 @@ func TestCacheOverflow(t *testing.T) { e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. - c.cache.Lock() - defer c.cache.Unlock() + c.mu.Lock() + defer c.mu.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) + if entry, ok := c.mu.table[e.addr]; ok { + t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) } } } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup @@ -137,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) { wg.Add(1) go func() { for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) } wg.Done() }() @@ -150,52 +156,53 @@ func TestCacheConcurrent(t *testing.T) { e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] - c.cache.Lock() - defer c.cache.Unlock() - if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) + c.mu.Lock() + defer c.mu.Unlock() + if entry, ok := c.mu.table[e.addr]; ok { + t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) } } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) linkRes := &testLinkAddressResolver{cache: c} e := testAddrs[0] - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) + _, _, err := c.get(e.addr, linkRes, "", nil, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err) } } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) e := testAddrs[0] l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } - c.add(e.addr, l2) + c.AddLinkAddress(e.addr, l2) got, _, err = c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) } if got != l2 { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) + t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2) } } @@ -206,15 +213,15 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) linkRes := &testLinkAddressResolver{cache: c} for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) + t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } @@ -223,16 +230,16 @@ func TestCacheResolution(t *testing.T) { e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} var requestCount uint32 @@ -244,17 +251,18 @@ func TestCacheResolutionFailed(t *testing.T) { e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) - e.addr.Addr += "2" - if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + e.addr += "2" + a, err := getBlocking(c, e.addr, linkRes) + if _, ok := err.(*tcpip.ErrTimeout); !ok { + t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -265,11 +273,12 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) + c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} e := testAddrs[0] - if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + a, err := getBlocking(c, e.addr, linkRes) + if _, ok := err.(*tcpip.ErrTimeout); !ok { + t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 61636cae5..64383bc7c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -45,6 +45,8 @@ const ( linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") + defaultPrefixLen = 128 + // Extra time to use when waiting for an async event to occur. defaultAsyncPositiveEventTimeout = 10 * time.Second @@ -102,7 +104,7 @@ type ndpDADEvent struct { nicID tcpip.NICID addr tcpip.Address resolved bool - err *tcpip.Error + err tcpip.Error } type ndpRouterEvent struct { @@ -172,7 +174,7 @@ type ndpDispatcher struct { } // Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ nicID, @@ -309,7 +311,7 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string { return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) } @@ -330,8 +332,12 @@ func TestDADDisabled(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } // Should get the address immediately since we should not have performed @@ -344,12 +350,8 @@ func TestDADDisabled(t *testing.T) { default: t.Fatal("expected DAD event") } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatal(err) } // We should not have sent any NDP NS messages. @@ -440,31 +442,31 @@ func TestDADResolve(t *testing.T) { NIC: nicID, }}) - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Make sure the address does not resolve before the resolution time has // passed. time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) } // Should not get a route even if we specify the local address as the // tentative address. { r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if err != tcpip.ErrNoRoute { - t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) } if r != nil { r.Release() @@ -472,8 +474,8 @@ func TestDADResolve(t *testing.T) { } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if err != tcpip.ErrNoRoute { - t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) } if r != nil { r.Release() @@ -493,10 +495,8 @@ func TestDADResolve(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if addr.Address != addr1 { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Error(err) } // Should get a route using the address now that it is resolved. { @@ -662,12 +662,8 @@ func TestDADFail(t *testing.T) { // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Receive a packet to simulate an address conflict. @@ -691,12 +687,8 @@ func TestDADFail(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Attempting to add the address again should not fail if the address's @@ -777,12 +769,8 @@ func TestDADStop(t *testing.T) { } // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } test.stopFn(t, s) @@ -800,12 +788,8 @@ func TestDADStop(t *testing.T) { } if !test.skipFinalAddrCheck { - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } } @@ -901,26 +885,25 @@ func TestSetNDPConfigurations(t *testing.T) { } // Add addresses for each NIC. - if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) } - if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err) + addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) } expectDADEvent(nicID2, addr2) - if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err) + addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) } expectDADEvent(nicID3, addr3) // Address should not be considered bound to NIC(1) yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Should get the address on NIC(2) and NIC(3) @@ -928,31 +911,19 @@ func TestSetNDPConfigurations(t *testing.T) { // it as the stack was configured to not do DAD by // default and we only updated the NDP configurations on // NIC(1). - addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2) + if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { + t.Fatal(err) } - addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3) + if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { + t.Fatal(err) } // Sleep until right (500ms before) before resolution to // make sure the address didn't resolve on NIC(1) yet. const delta = 500 * time.Millisecond time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Wait for DAD to resolve. @@ -970,12 +941,8 @@ func TestSetNDPConfigurations(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { + t.Fatal(err) } }) } @@ -2808,6 +2775,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2827,10 +2795,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN Gateway: llAddr3, NIC: nicID, }}) + if useNeighborCache { - s.AddStaticNeighbor(nicID, llAddr3, linkAddr3) + if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + } } else { - s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + } } return ndpDisp, e, s } @@ -2940,10 +2913,8 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { @@ -3088,10 +3059,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { @@ -3238,10 +3207,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("should not have %s in the list of addresses", addr2) } // Should not have any primary endpoints. - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); got != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) @@ -3255,8 +3222,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { defer ep.Close() ep.SocketOptions().SetV6Only(true) - if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) + { + err := ep.Connect(dstAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) + } } }) } @@ -3615,10 +3585,8 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index acee72572..88a3ff776 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() @@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // a node continues sending packets to that neighbor using the cached // link-layer address." if onResolve != nil { - onResolve(entry.neigh.LinkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true}) } return entry.neigh, nil, nil case Unknown, Incomplete, Failed: @@ -154,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } entry.handlePacketQueuedLocked(localAddr) - return entry.neigh, entry.done, tcpip.ErrWouldBlock + return entry.neigh, entry.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } @@ -297,10 +297,9 @@ func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Li // no matching entry for the remote address. } -// HandleUpperLevelConfirmation implements -// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in -// RFC 4861 section 7.3.1. -func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { +// handleUpperLevelConfirmation processes a confirmation of reachablity from +// some protocol that operates at a layer above the IP/link layer. +func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b96a56612..2870e4f66 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -194,7 +194,7 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { if !r.dropReplies { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { @@ -251,8 +251,8 @@ func TestNeighborCacheGetConfig(t *testing.T) { // No events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -273,8 +273,8 @@ func TestNeighborCacheSetConfig(t *testing.T) { // No events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -295,8 +295,9 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -321,11 +322,11 @@ func TestNeighborCacheEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { @@ -335,8 +336,8 @@ func TestNeighborCacheEntry(t *testing.T) { // No more events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -359,8 +360,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -385,11 +387,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } neigh.removeEntry(entry.Addr) @@ -407,15 +409,18 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + { + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + } } } @@ -462,8 +467,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -506,11 +512,11 @@ func (c *testContext) overflowCache(opts overflowOptions) error { }) c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -531,15 +537,15 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } return nil @@ -579,8 +585,9 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ @@ -603,11 +610,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the entry @@ -626,11 +633,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -668,11 +675,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the static entry that was just added @@ -681,8 +688,8 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -712,11 +719,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Add a duplicate entry with a different link address @@ -736,8 +743,8 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) } c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -774,11 +781,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the static entry that was just added @@ -796,11 +803,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -830,8 +837,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -854,11 +862,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Override the entry with a static one using the same address @@ -886,11 +894,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -932,8 +940,8 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { LinkAddr: entry.LinkAddr, State: Static, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -948,11 +956,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } opts := overflowOptions{ @@ -989,8 +997,9 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -1014,11 +1023,11 @@ func TestNeighborCacheClear(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Add a static entry. @@ -1037,11 +1046,11 @@ func TestNeighborCacheClear(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1072,8 +1081,8 @@ func TestNeighborCacheClear(t *testing.T) { } nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEvents, nudDisp.events, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1094,8 +1103,9 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1118,11 +1128,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Clear the cache. @@ -1140,11 +1150,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1188,16 +1198,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1225,11 +1232,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1247,16 +1254,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1299,11 +1303,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1331,15 +1335,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } // No more events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1366,8 +1370,10 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) { + case nil, *tcpip.ErrWouldBlock: + default: + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) } }(entry) } @@ -1398,8 +1404,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1423,16 +1429,13 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1455,8 +1458,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: entry.LinkAddr, State: Reachable, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } @@ -1489,8 +1492,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: updatedLinkAddr, State: Delay, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } @@ -1507,8 +1510,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: updatedLinkAddr, State: Reachable, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } } @@ -1539,16 +1542,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1567,8 +1567,8 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { LinkAddr: entry.LinkAddr, State: Reachable, } - if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } // Verify address resolution fails for an unknown address. @@ -1576,19 +1576,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1627,19 +1621,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1674,19 +1662,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1713,13 +1695,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) @@ -1772,16 +1754,13 @@ func BenchmarkCacheClear(b *testing.B) { b.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - b.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 75afb3001..53ac9bb6e 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -96,7 +96,7 @@ type neighborEntry struct { done chan struct{} // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + onResolve []func(LinkResolutionResult) isRouter bool job *tcpip.Job @@ -143,13 +143,22 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd // // Precondition: e.mu MUST be locked. func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded} for _, callback := range e.onResolve { - callback(e.neigh.LinkAddr, succeeded) + callback(res) } e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index ec34ffa5a..140b8ca00 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { p := entryTestProbeInfo{ RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, @@ -266,16 +266,16 @@ func TestEntryInitiallyUnknown(t *testing.T) { // No probes should have been sent. linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } // No events should have been dispatched. nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -299,16 +299,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { // No probes should have been sent. linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } // No events should have been dispatched. nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -333,10 +333,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -352,10 +352,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) { } { nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -374,10 +374,10 @@ func TestEntryUnknownToStale(t *testing.T) { // No probes should have been sent. runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -392,8 +392,8 @@ func TestEntryUnknownToStale(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -427,11 +427,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -453,10 +453,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -483,8 +483,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -515,10 +515,10 @@ func TestEntryIncompleteToReachable(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -553,8 +553,8 @@ func TestEntryIncompleteToReachable(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -579,10 +579,10 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -620,8 +620,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -646,10 +646,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -684,8 +684,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -710,10 +710,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -744,8 +744,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -785,10 +785,10 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -812,8 +812,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -850,10 +850,10 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -903,8 +903,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -932,10 +932,10 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -977,8 +977,8 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1005,10 +1005,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1054,8 +1054,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -1083,10 +1083,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1134,8 +1134,8 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1157,10 +1157,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1212,8 +1212,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1235,10 +1235,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1290,8 +1290,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1313,10 +1313,10 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1358,8 +1358,8 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1381,10 +1381,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1439,8 +1439,8 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1462,10 +1462,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1520,8 +1520,8 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1543,10 +1543,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1601,8 +1601,8 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1624,10 +1624,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1678,8 +1678,8 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1701,10 +1701,10 @@ func TestEntryStaleToDelay(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1752,8 +1752,8 @@ func TestEntryStaleToDelay(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1780,10 +1780,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1851,8 +1851,8 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1880,10 +1880,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1958,8 +1958,8 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1987,10 +1987,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2065,8 +2065,8 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2088,10 +2088,10 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2147,8 +2147,8 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2170,10 +2170,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2231,8 +2231,8 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2254,10 +2254,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2319,8 +2319,8 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2343,11 +2343,11 @@ func TestEntryDelayToProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2372,10 +2372,10 @@ func TestEntryDelayToProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2418,8 +2418,8 @@ func TestEntryDelayToProbe(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -2448,11 +2448,11 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2474,10 +2474,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2539,8 +2539,8 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2563,11 +2563,11 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2589,10 +2589,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2658,8 +2658,8 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2682,11 +2682,11 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2709,10 +2709,10 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2772,8 +2772,8 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2806,10 +2806,10 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2878,8 +2878,8 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2907,11 +2907,11 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2933,10 +2933,10 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3015,8 +3015,8 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3044,11 +3044,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3070,10 +3070,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3149,8 +3149,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3178,11 +3178,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3204,10 +3204,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3283,8 +3283,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3309,11 +3309,11 @@ func TestEntryProbeToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3336,11 +3336,11 @@ func TestEntryProbeToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + t.Fatalf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) } e.mu.Lock() @@ -3406,8 +3406,8 @@ func TestEntryProbeToFailed(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3449,10 +3449,10 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -3498,8 +3498,8 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0f545f255..e56a624fe 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -53,6 +53,8 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution + linkAddrCache *linkAddrCache + mu struct { sync.RWMutex spoofing bool @@ -138,7 +140,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.linkResQueue.init() + nic.linkResQueue.init(nic) + nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. @@ -167,7 +170,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) } nic.LinkEndpoint.Attach(nic) @@ -228,7 +231,9 @@ func (n *NIC) disableLocked() { // // This matches linux's behaviour at the time of writing: // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 - if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported { + switch err := n.clearNeighbors(); err.(type) { + case nil, *tcpip.ErrNotSupported: + default: panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) } @@ -243,7 +248,7 @@ func (n *NIC) disableLocked() { // address (ff02::1), start DAD for permanent addresses, and start soliciting // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. -func (n *NIC) enable() *tcpip.Error { +func (n *NIC) enable() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -263,7 +268,7 @@ func (n *NIC) enable() *tcpip.Error { // remove detaches NIC from the link endpoint and releases network endpoint // resources. This guarantees no packets between this NIC and the network // stack. -func (n *NIC) remove() *tcpip.Error { +func (n *NIC) remove() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -299,48 +304,69 @@ func (n *NIC) IsLoopback() bool { } // WritePacket implements NetworkLinkEndpoint. -func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { - // As per relevant RFCs, we should queue packets while we wait for link - // resolution to complete. - // - // RFC 1122 section 2.3.2.2 (for IPv4): - // The link layer SHOULD save (rather than discard) at least - // one (the latest) packet of each set of packets destined to - // the same unresolved IP address, and transmit the saved - // packet when the address has been resolved. - // - // RFC 4861 section 7.2.2 (for IPv6): - // While waiting for address resolution to complete, the sender MUST, for - // each neighbor, retain a small queue of packets waiting for address - // resolution to complete. The queue MUST hold at least one packet, and MAY - // contain more. However, the number of queued packets per neighbor SHOULD - // be limited to some small value. When a queue overflows, the new arrival - // SHOULD replace the oldest entry. Once address resolution completes, the - // node transmits any queued packets. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - r.Acquire() - n.linkResQueue.enqueue(ch, r, protocol, pkt) - return nil +func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { + _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) + return err +} + +func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + switch pkt := pkt.(type) { + case *PacketBuffer: + if err := n.writePacket(r, gso, protocol, pkt); err != nil { + return 0, err } - return err + return 1, nil + case *PacketBufferList: + return n.writePackets(r, gso, protocol, *pkt) + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) } +} - return n.writePacket(r.Fields(), gso, protocol, pkt) +func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + routeInfo, _, err := r.resolvedFields(nil) + switch err.(type) { + case nil: + return n.writePacketBuffer(routeInfo, gso, protocol, pkt) + case *tcpip.ErrWouldBlock: + // As per relevant RFCs, we should queue packets while we wait for link + // resolution to complete. + // + // RFC 1122 section 2.3.2.2 (for IPv4): + // The link layer SHOULD save (rather than discard) at least + // one (the latest) packet of each set of packets destined to + // the same unresolved IP address, and transmit the saved + // packet when the address has been resolved. + // + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and + // MAY contain more. However, the number of queued packets per neighbor + // SHOULD be limited to some small value. When a queue overflows, the new + // arrival SHOULD replace the oldest entry. Once address resolution + // completes, the node transmits any queued packets. + return n.linkResQueue.enqueue(r, gso, protocol, pkt) + default: + return 0, err + } } // WritePacketToRemote implements NetworkInterface. -func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr return n.writePacket(r, gso, protocol, pkt) } -func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { return err } @@ -351,10 +377,18 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. -func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution - // is being peformed like WritePacket. - writtenPackets, err := n.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) +func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return n.enqueuePacketBuffer(r, gso, protocol, &pkts) +} + +func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol + } + + writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { @@ -463,15 +497,15 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { +func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) @@ -535,63 +569,70 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit } // removeAddress removes an address from n. -func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { for _, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { continue } - if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + switch err := addressableEndpoint.RemovePermanentAddress(addr); err.(type) { + case *tcpip.ErrBadLocalAddress: continue - } else { + default: return err } } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} +} + +func (n *NIC) confirmReachable(addr tcpip.Address) { + if n := n.neigh; n != nil { + n.handleUpperLevelConfirmation(addr) + } } -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { if n.neigh != nil { entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) return entry.LinkAddr, ch, err } - return n.stack.linkAddrCache.get(tcpip.FullAddress{NIC: n.ID(), Addr: addr}, linkRes, localAddr, n, onResolve) + return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve) } -func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { +func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { if n.neigh == nil { - return nil, tcpip.ErrNotSupported + return nil, &tcpip.ErrNotSupported{} } return n.neigh.entries(), nil } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { +func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } n.neigh.addStaticEntry(addr, linkAddress) return nil } -func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } if !n.neigh.removeEntry(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } return nil } -func (n *NIC) clearNeighbors() *tcpip.Error { +func (n *NIC) clearNeighbors() tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } n.neigh.clear() @@ -600,7 +641,7 @@ func (n *NIC) clearNeighbors() *tcpip.Error { // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. -func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as @@ -608,12 +649,12 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address ep, ok := n.networkEndpoints[protocol] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } gep, ok := ep.(GroupAddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } return gep.JoinGroup(addr) @@ -621,15 +662,15 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { ep, ok := n.networkEndpoints[protocol] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } gep, ok := ep.(GroupAddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } return gep.LeaveGroup(addr) @@ -879,9 +920,9 @@ func (n *NIC) Name() string { } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { +func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { if n.neigh == nil { - return NUDConfigurations{}, tcpip.ErrNotSupported + return NUDConfigurations{}, &tcpip.ErrNotSupported{} } return n.neigh.config(), nil } @@ -890,22 +931,22 @@ func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { +func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } c.resetInvalidFields() n.neigh.setConfig(c) return nil } -func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { n.mu.Lock() defer n.mu.Unlock() eps, ok := n.mu.packetEPs[netProto] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } eps.add(ep) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 5b5c58afb..2f719fbe5 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -39,7 +39,7 @@ type testIPv6Endpoint struct { invalidatedRtr tcpip.Address } -func (*testIPv6Endpoint) Enable() *tcpip.Error { +func (*testIPv6Endpoint) Enable() tcpip.Error { return nil } @@ -65,21 +65,21 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { } // WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error { return nil } // WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { // Our tests don't use this so we don't support it. - return 0, tcpip.ErrNotSupported + return 0, &tcpip.ErrNotSupported{} } // WriteHeaderIncludedPacket implements // NetworkEndpoint.WriteHeaderIncludedPacket. -func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error { // Our tests don't use this so we don't support it. - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } // HandlePacket implements NetworkEndpoint.HandlePacket. @@ -99,11 +99,20 @@ func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.invalidatedRtr = rtr } -var _ NetworkProtocol = (*testIPv6Protocol)(nil) +// Stats implements NetworkEndpoint. +func (*testIPv6Endpoint) Stats() NetworkEndpointStats { + return &testIPv6EndpointStats{} +} + +var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil) + +type testIPv6EndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) -// An IPv6 NetworkProtocol that supports the bare minimum to make a stack -// believe it supports IPv6. -// // We use this instead of ipv6.protocol because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Protocol struct{} @@ -140,12 +149,12 @@ func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, } // SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { return nil } // Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { return nil } @@ -160,15 +169,13 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } -var _ LinkAddressResolver = (*testIPv6Protocol)(nil) - // LinkAddressProtocol implements LinkAddressResolver. func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 12d67409a..77926e289 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -174,10 +174,6 @@ type NUDHandler interface { // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) - - // HandleUpperLevelConfirmation processes an incoming upper-level protocol - // (e.g. TCP acknowledgements) reachability confirmation. - HandleUpperLevelConfirmation(addr tcpip.Address) } // NUDConfigurations is the NUD configurations for the netstack. This is used diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 7bca1373e..ebfd5eb45 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -65,8 +65,9 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { // No NIC with ID 1 yet. config := stack.NUDConfigurations{} - if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + err := s.SetNUDConfigurations(1, config) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{}) } } @@ -90,8 +91,9 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) + _, err := s.NUDConfigurations(nicID) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) } } @@ -117,8 +119,9 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { } config := stack.NUDConfigurations{} - if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + err := s.SetNUDConfigurations(nicID, config) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) } } diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 41529ffd5..1c651e216 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -28,108 +28,205 @@ const ( maxPendingPacketsPerResolution = 256 ) +// pendingPacketBuffer is a pending packet buffer. +// +// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use +// WritePackets so we can use a PacketBufferList everywhere. +type pendingPacketBuffer interface { + len() int +} + +func (*PacketBuffer) len() int { + return 1 +} + +func (p *PacketBufferList) len() int { + return p.Len() +} + type pendingPacket struct { - route *Route - proto tcpip.NetworkProtocolNumber - pkt *PacketBuffer + routeInfo RouteInfo + gso *GSO + proto tcpip.NetworkProtocolNumber + pkt pendingPacketBuffer } // packetsPendingLinkResolution is a queue of packets pending link resolution. // // Once link resolution completes successfully, the packets will be written. type packetsPendingLinkResolution struct { - sync.Mutex + nic *NIC - // The packets to send once the resolver completes. - packets map[<-chan struct{}][]pendingPacket + mu struct { + sync.Mutex - // FIFO of channels used to cancel the oldest goroutine waiting for - // link-address resolution. - cancelChans []chan struct{} -} + // The packets to send once the resolver completes. + // + // The link resolution channel is used as the key for this map. + packets map[<-chan struct{}][]pendingPacket -func (f *packetsPendingLinkResolution) init() { - f.Lock() - defer f.Unlock() - f.packets = make(map[<-chan struct{}][]pendingPacket) + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + // + // cancelChans holds the same channels that are used as keys to packets. + cancelChans []<-chan struct{} + } } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - f.Lock() - defer f.Unlock() +func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { + n := uint64(pkt.len()) + f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n) - packets, ok := f.packets[ch] - if len(packets) == maxPendingPacketsPerResolution { - p := packets[0] - packets[0] = pendingPacket{} - packets = packets[1:] - p.route.Stats().IP.OutgoingPacketErrors.Increment() - p.route.Release() + if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { + ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n) } +} - if l := len(packets); l >= maxPendingPacketsPerResolution { - panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) +func (f *packetsPendingLinkResolution) init(nic *NIC) { + f.mu.Lock() + defer f.mu.Unlock() + f.nic = nic + f.mu.packets = make(map[<-chan struct{}][]pendingPacket) +} + +// dequeue any pending packets associated with ch. +// +// If success is true, packets will be written and sent to the given remote link +// address. +func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) { + f.mu.Lock() + packets, ok := f.mu.packets[ch] + delete(f.mu.packets, ch) + + if ok { + for i, cancelChan := range f.mu.cancelChans { + if cancelChan == ch { + f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...) + break + } + } } - f.packets[ch] = append(packets, pendingPacket{ - route: r, - proto: proto, - pkt: pkt, - }) + f.mu.Unlock() if ok { - return + f.dequeuePackets(packets, linkAddr, success) } +} - // Wait for the link-address resolution to complete. - cancel := f.newCancelChannelLocked() - go func() { - cancelled := false - select { - case <-ch: - case <-cancel: - cancelled = true - } +// enqueue a packet to be sent once link resolution completes. +// +// If the maximum number of pending resolutions is reached, the packets +// associated with the oldest link resolution will be dequeued as if they failed +// link resolution. +func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + f.mu.Lock() + // Make sure we attempt resolution while holding f's lock so that we avoid + // a race where link resolution completes before we enqueue the packets. + // + // A @ T1: Call ResolvedFields (get link resolution channel) + // B @ T2: Complete link resolution, dequeue pending packets + // C @ T1: Enqueue packet that already completed link resolution (which will + // never dequeue) + // + // To make sure B does not interleave with A and C, we make sure A and C are + // done while holding the lock. + routeInfo, ch, err := r.resolvedFields(nil) + switch err.(type) { + case nil: + // The route resolved immediately, so we don't need to wait for link + // resolution to send the packet. + f.mu.Unlock() + return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt) + case *tcpip.ErrWouldBlock: + // We need to wait for link resolution to complete. + default: + f.mu.Unlock() + return 0, err + } - f.Lock() - packets, ok := f.packets[ch] - delete(f.packets, ch) - f.Unlock() + defer f.mu.Unlock() - if !ok { - panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) - } + packets, ok := f.mu.packets[ch] + packets = append(packets, pendingPacket{ + routeInfo: routeInfo, + gso: gso, + proto: proto, + pkt: pkt, + }) - for _, p := range packets { - if cancelled || p.route.IsResolutionRequired() { - p.route.Stats().IP.OutgoingPacketErrors.Increment() + if len(packets) > maxPendingPacketsPerResolution { + f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt) + packets[0] = pendingPacket{} + packets = packets[1:] - if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { - linkResolvableEP.HandleLinkResolutionFailure(pkt) - } - } else { - p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt) - } - p.route.Release() + if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution { + panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution)) } - }() + } + + f.mu.packets[ch] = packets + + if ok { + return pkt.len(), nil + } + + cancelledPackets := f.newCancelChannelLocked(ch) + + if len(cancelledPackets) != 0 { + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as handing link resolution failures may be a costly operation. + go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */) + } + + return pkt.len(), nil } -// newCancelChannel creates a channel that can cancel a pending forwarding -// activity. The oldest channel is closed if the number of open channels would -// exceed maxPendingResolutions. -func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { - if len(f.cancelChans) == maxPendingResolutions { - ch := f.cancelChans[0] - f.cancelChans[0] = nil - f.cancelChans = f.cancelChans[1:] - close(ch) +// newCancelChannelLocked appends the link resolution channel to a FIFO. If the +// maximum number of pending resolutions is reached, the oldest channel will be +// removed and its associated pending packets will be returned. +func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket { + f.mu.cancelChans = append(f.mu.cancelChans, newCH) + if len(f.mu.cancelChans) <= maxPendingResolutions { + return nil } - if l := len(f.cancelChans); l >= maxPendingResolutions { + + ch := f.mu.cancelChans[0] + f.mu.cancelChans[0] = nil + f.mu.cancelChans = f.mu.cancelChans[1:] + if l := len(f.mu.cancelChans); l > maxPendingResolutions { panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) } - ch := make(chan struct{}) - f.cancelChans = append(f.cancelChans, ch) - return ch + packets, ok := f.mu.packets[ch] + if !ok { + panic("must have a packet queue for an uncancelled channel") + } + delete(f.mu.packets, ch) + + return packets +} + +func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) { + for _, p := range packets { + if success { + p.routeInfo.RemoteLinkAddress = linkAddr + _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + } else { + f.incrementOutgoingPacketErrors(p.proto, p.pkt) + + if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok { + switch pkt := p.pkt.(type) { + case *PacketBuffer: + linkResolvableEP.HandleLinkResolutionFailure(pkt) + case *PacketBufferList: + for pb := pkt.Front(); pb != nil; pb = pb.Next() { + linkResolvableEP.HandleLinkResolutionFailure(pb) + } + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) + } + } + } + } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 68c113b6a..510da8689 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -172,10 +172,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -184,7 +184,7 @@ type TransportProtocol interface { // ParsePorts returns the source and destination ports stored in a // packet of this protocol. - ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) + ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this // protocol that don't match any existing endpoint. For example, @@ -197,12 +197,12 @@ type TransportProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -289,10 +289,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - JoinGroup(group tcpip.Address) *tcpip.Error + JoinGroup(group tcpip.Address) tcpip.Error // LeaveGroup attempts to leave the specified group. - LeaveGroup(group tcpip.Address) *tcpip.Error + LeaveGroup(group tcpip.Address) tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool @@ -440,17 +440,17 @@ func (k AddressKind) IsPermanent() bool { type AddressableEndpoint interface { // AddAndAcquirePermanentAddress adds the passed permanent address. // - // Returns tcpip.ErrDuplicateAddress if the address exists. + // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. // - // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed + // Returns *tcpip.ErrBadLocalAddress if the endpoint does not have the passed // permanent address. - RemovePermanentAddress(addr tcpip.Address) *tcpip.Error + RemovePermanentAddress(addr tcpip.Address) tcpip.Error // MainAddress returns the endpoint's primary permanent address. MainAddress() tcpip.AddressWithPrefix @@ -512,14 +512,14 @@ type NetworkInterface interface { Promiscuous() bool // WritePacketToRemote writes the packet to the given remote link address. - WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePacket writes a packet with the given protocol through the given // route. // // WritePacket takes ownership of the packet buffer. The packet buffer's // network and transport header must be set. - WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. @@ -529,7 +529,7 @@ type NetworkInterface interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // LinkResolvableNetworkEndpoint handles link resolution events. @@ -547,8 +547,8 @@ type NetworkEndpoint interface { // Must only be called when the stack is in a state that allows the endpoint // to send and receive packets. // - // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled. - Enable() *tcpip.Error + // Returns *tcpip.ErrNotPermitted if the endpoint cannot be enabled. + Enable() tcpip.Error // Enabled returns true if the endpoint is enabled. Enabled() bool @@ -574,16 +574,16 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and // underlying packets. - WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. - WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. @@ -597,6 +597,26 @@ type NetworkEndpoint interface { // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for // this endpoint. NetworkProtocolNumber() tcpip.NetworkProtocolNumber + + // Stats returns a reference to the network endpoint stats. + Stats() NetworkEndpointStats +} + +// NetworkEndpointStats is the interface implemented by each network endpoint +// stats struct. +type NetworkEndpointStats interface { + // IsNetworkEndpointStats is an empty method to implement the + // NetworkEndpointStats marker interface. + IsNetworkEndpointStats() +} + +// IPNetworkEndpointStats is a NetworkEndpointStats that tracks IP-related +// statistics. +type IPNetworkEndpointStats interface { + NetworkEndpointStats + + // IPStats returns the IP statistics of a network endpoint. + IPStats() *tcpip.IPStats } // ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. @@ -634,12 +654,12 @@ type NetworkProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error + SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error + Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -776,7 +796,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. @@ -786,7 +806,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -801,7 +821,7 @@ type InjectableLinkEndpoint interface { // link. // // dest is used by endpoints with multiple raw destinations. - InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error + InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } // A LinkAddressResolver is an extension to a NetworkProtocol that @@ -813,7 +833,7 @@ type LinkAddressResolver interface { // // The request is sent from the passed network interface. If the interface // local address is unspecified, any interface local address may be used. - LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the @@ -829,12 +849,8 @@ type LinkAddressResolver interface { // A LinkAddressCache caches link addresses. type LinkAddressCache interface { - // CheckLocalAddress determines if the given local address exists, and if it - // does not exist. - CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID - // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress) } // RawFactory produces endpoints for writing various types of raw packets. @@ -842,11 +858,11 @@ type RawFactory interface { // NewUnassociatedEndpoint produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can // be used to write arbitrary packets that include the network header. - NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // NewPacketEndpoint produces endpoints for reading and writing packets // that include network and (when cooked is false) link layer headers. - NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) } // GSOType is the type of GSO segments. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1ff7b3a37..4ae0f2a1a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -86,12 +86,21 @@ type RouteInfo struct { RemoteLinkAddress tcpip.LinkAddress } -// Fields returns a RouteInfo with all of r's exported fields. This allows -// callers to store the route's fields without retaining a reference to it. +// Fields returns a RouteInfo with all of the known values for the route's +// fields. +// +// If any fields are unknown (e.g. remote link address when it is waiting for +// link address resolution), they will be unset. func (r *Route) Fields() RouteInfo { + r.mu.RLock() + defer r.mu.RUnlock() + return r.fieldsLocked() +} + +func (r *Route) fieldsLocked() RouteInfo { return RouteInfo{ routeInfo: r.routeInfo, - RemoteLinkAddress: r.RemoteLinkAddress(), + RemoteLinkAddress: r.mu.remoteLinkAddress, } } @@ -306,32 +315,43 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. +// ResolvedFieldsResult is the result of a route resolution attempt. +type ResolvedFieldsResult struct { + RouteInfo RouteInfo + Success bool +} + +// ResolvedFields attempts to resolve the remote link address if it is not +// known. // -// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. -// waiting for ARP reply). If address resolution is required, a notification -// channel is also returned for the caller to block on. The channel is closed -// once address resolution is complete (successful or not). If a callback is -// provided, it will be called when address resolution is complete, regardless +// If a callback is provided, it will be called before ResolvedFields returns +// when address resolution is not required. If address resolution is required, +// the callback will be called once address resolution is complete, regardless // of success or failure. -func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { - r.mu.Lock() - - if !r.isResolutionRequiredRLocked() { - // Nothing to do if there is no cache (which does the resolution on cache miss) or - // link address is already known. - r.mu.Unlock() - return nil, nil - } - - // Increment the route's reference count because finishResolution retains a - // reference to the route and releases it when called. - r.acquireLocked() - r.mu.Unlock() +// +// Note, the route will not cache the remote link address when address +// resolution completes. +func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) tcpip.Error { + _, _, err := r.resolvedFields(afterResolve) + return err +} - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress +// resolvedFields is like ResolvedFields but also returns a notification channel +// when address resolution is required. This channel will become readable once +// address resolution is complete. +// +// The route's fields will also be returned, regardless of whether address +// resolution is required or not. +func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, tcpip.Error) { + r.mu.RLock() + fields := r.fieldsLocked() + resolutionRequired := r.isResolutionRequiredRLocked() + r.mu.RUnlock() + if !resolutionRequired { + if afterResolve != nil { + afterResolve(ResolvedFieldsResult{RouteInfo: fields, Success: true}) + } + return fields, nil, nil } // If specified, the local address used for link address resolution must be an @@ -341,18 +361,27 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } - finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { - if ok { - r.ResolveWith(linkAddress) - } + afterResolveFields := fields + linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { if afterResolve != nil { - afterResolve() + if r.Success { + afterResolveFields.RemoteLinkAddress = r.LinkAddress + } + + afterResolve(ResolvedFieldsResult{RouteInfo: afterResolveFields, Success: r.Success}) } - r.Release() + }) + if err == nil { + fields.RemoteLinkAddress = linkAddr } + return fields, ch, err +} - _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) - return ch, err +func (r *Route) nextHop() tcpip.Address { + if len(r.NextHop) == 0 { + return r.RemoteAddress + } + return r.NextHop } // local returns true if the route is a local route. @@ -371,11 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { - return false - } - - return r.linkRes != nil + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -404,9 +429,9 @@ func (r *Route) isValidForOutgoingRLocked() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) @@ -414,9 +439,9 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf // WritePackets writes a list of n packets through the given route and returns // the number of packets written. -func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { if !r.isValidForOutgoing() { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) @@ -424,9 +449,9 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) @@ -496,3 +521,12 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } + +// ConfirmReachable informs the network/link layer that the neighbour used for +// the route is reachable. +// +// "Reachable" is defined as having full-duplex communication between the +// local and remote ends of the route. +func (r *Route) ConfirmReachable() { + r.outgoingNIC.confirmReachable(r.nextHop()) +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c0aec61a6..119c4c505 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -76,12 +76,16 @@ type TCPCubicState struct { // TCPRACKState is used to hold a copy of the internal RACK state when the // TCPProbeFunc is invoked. type TCPRACKState struct { - XmitTime time.Time - EndSequence seqnum.Value - FACK seqnum.Value - RTT time.Duration - Reord bool - DSACKSeen bool + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool + DSACKSeen bool + ReoWnd time.Duration + ReoWndIncr uint8 + ReoWndPersist int8 + RTTSeq seqnum.Value } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -382,8 +386,6 @@ type Stack struct { stats tcpip.Stats - linkAddrCache *linkAddrCache - mu sync.RWMutex nics map[tcpip.NICID]*NIC @@ -446,7 +448,7 @@ type Stack struct { // sendBufferSize holds the min/default/max send buffer sizes for // endpoints other than TCP. - sendBufferSize SendBufferSizeOption + sendBufferSize tcpip.SendBufferSizeOption // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. @@ -554,7 +556,7 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { netProto := t.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: @@ -572,11 +574,11 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl switch len(t.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { - return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{} } case header.IPv6AddressSize: if len(addr.Addr) == header.IPv4AddressSize { - return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + return tcpip.FullAddress{}, 0, &tcpip.ErrNetworkUnreachable{} } } @@ -584,10 +586,10 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl case netProto == t.NetProto: case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: if v6only { - return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + return tcpip.FullAddress{}, 0, &tcpip.ErrNoRoute{} } default: - return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{} } return addr, netProto, nil @@ -636,7 +638,6 @@ func New(opts Options) *Stack { linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), cleanupEndpoints: make(map[TransportEndpoint]struct{}), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), PortManager: ports.NewPortManager(), clock: clock, stats: opts.Stats.FillIn(), @@ -649,7 +650,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), - sendBufferSize: SendBufferSizeOption{ + sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, @@ -701,10 +702,10 @@ func (s *Stack) UniqueID() uint64 { // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return netProto.SetOption(option) } @@ -718,10 +719,10 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op // if err != nil { // ... // } -func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return netProto.Option(option) } @@ -730,10 +731,10 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return transProtoState.proto.SetOption(option) } @@ -745,10 +746,10 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... // } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return transProtoState.proto.Option(option) } @@ -781,15 +782,15 @@ func (s *Stack) Stats() tcpip.Stats { // SetForwarding enables or disables packet forwarding between NICs for the // passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error { +func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { protocol, ok := s.networkProtocols[protocolNum] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } forwardingProtocol.SetForwarding(enable) @@ -852,10 +853,10 @@ func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { } // NewEndpoint creates a new transport layer endpoint of the given protocol. -func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { t, ok := s.transportProtocols[transport] if !ok { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return t.proto.NewEndpoint(network, waiterQueue) @@ -864,9 +865,9 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // NewRawEndpoint creates a new raw transport layer endpoint of the given // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. -func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { if s.rawFactory == nil { - return nil, tcpip.ErrNotPermitted + return nil, &tcpip.ErrNotPermitted{} } if !associated { @@ -875,7 +876,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network t, ok := s.transportProtocols[transport] if !ok { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return t.proto.NewRawEndpoint(network, waiterQueue) @@ -883,9 +884,9 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // NewPacketEndpoint creates a new packet endpoint listening for the given // netProto. -func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if s.rawFactory == nil { - return nil, tcpip.ErrNotPermitted + return nil, &tcpip.ErrNotPermitted{} } return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) @@ -916,20 +917,20 @@ type NICOptions struct { // NICs can be configured. // // LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. -func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { +func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() // Make sure id is unique. if _, ok := s.nics[id]; ok { - return tcpip.ErrDuplicateNICID + return &tcpip.ErrDuplicateNICID{} } // Make sure name is unique, unless unnamed. if opts.Name != "" { for _, n := range s.nics { if n.Name() == opts.Name { - return tcpip.ErrDuplicateNICID + return &tcpip.ErrDuplicateNICID{} } } } @@ -945,7 +946,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp // CreateNIC creates a NIC with the provided id and LinkEndpoint and calls // LinkEndpoint.Attach to bind ep with a NetworkDispatcher. -func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } @@ -963,26 +964,26 @@ func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. -func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) EnableNIC(id tcpip.NICID) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.enable() } // DisableNIC disables the given NIC. -func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) DisableNIC(id tcpip.NICID) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.disable() @@ -1003,7 +1004,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // RemoveNIC removes NIC and all related routes from the network stack. -func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) RemoveNIC(id tcpip.NICID) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -1013,10 +1014,10 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { // removeNICLocked removes NIC and all related routes from the network stack. // // s.mu must be locked. -func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { +func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } delete(s.nics, id) @@ -1050,6 +1051,9 @@ type NICInfo struct { Stats NICStats + // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. + NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats + // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext @@ -1081,6 +1085,12 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } + + netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats) + for proto, netEP := range nic.networkEndpoints { + netStats[proto] = netEP.Stats() + } + nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.LinkEndpoint.LinkAddress(), @@ -1088,6 +1098,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.LinkEndpoint.MTU(), Stats: nic.stats, + NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), } @@ -1111,13 +1122,13 @@ type NICStateFlags struct { } // AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } // AddAddressWithPrefix is the same as AddAddress, but allows you to specify // the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error { +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { ap := tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: addr, @@ -1127,16 +1138,16 @@ func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProto // AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. -func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) } // AddAddressWithOptions is the same as AddAddress, but allows you to specify // whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { +func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { netProto, ok := s.networkProtocols[protocol] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ Protocol: protocol, @@ -1149,13 +1160,13 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt // AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows // you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { +func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.addAddress(protocolAddress, peb) @@ -1163,7 +1174,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc // RemoveAddress removes an existing network-layer address from the specified // NIC. -func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -1171,7 +1182,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { return nic.removeAddress(addr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // AllAddresses returns a map of NICIDs to their protocol addresses (primary @@ -1189,19 +1200,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { // GetMainNICAddress returns the first non-deprecated primary address and prefix // for the given NIC and protocol. If no non-deprecated primary address exists, -// a deprecated primary address and prefix will be returned. Returns an error if +// a deprecated primary address and prefix will be returned. Returns false if // the NIC doesn't exist and an empty value if the NIC doesn't have a primary // address for the given protocol. -func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { +func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + return tcpip.AddressWithPrefix{}, false } - return nic.primaryAddress(protocol), nil + return nic.primaryAddress(protocol), true } func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { @@ -1301,7 +1312,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1337,9 +1348,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return nil, tcpip.ErrBadLocalAddress + return nil, &tcpip.ErrBadLocalAddress{} } - return nil, tcpip.ErrNetworkUnreachable + return nil, &tcpip.ErrNetworkUnreachable{} } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal @@ -1405,7 +1416,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } - return nil, tcpip.ErrNoRoute + return nil, &tcpip.ErrNoRoute{} } if id == 0 { @@ -1425,12 +1436,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return nil, tcpip.ErrNoRoute + return nil, &tcpip.ErrNoRoute{} } if header.IsV6LoopbackAddress(remoteAddr) { - return nil, tcpip.ErrBadLocalAddress + return nil, &tcpip.ErrBadLocalAddress{} } - return nil, tcpip.ErrNetworkUnreachable + return nil, &tcpip.ErrNetworkUnreachable{} } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1476,13 +1487,13 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto } // SetPromiscuousMode enables or disables promiscuous mode in the given NIC. -func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error { +func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.setPromiscuousMode(enable) @@ -1492,13 +1503,13 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error // SetSpoofing enables or disables address spoofing in the given NIC, allowing // endpoints to bind to any address in the NIC. -func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { +func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.setSpoofing(enable) @@ -1506,17 +1517,27 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { return nil } -// AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.add(fullAddr, linkAddr) - // TODO: provide a way for a transport endpoint to receive a signal - // that AddLinkAddress for a particular address has been called. +// AddLinkAddress adds a link address for the neighbor on the specified NIC. +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[nicID] + if !ok { + return &tcpip.ErrUnknownNICID{} + } + + nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr) + return nil +} + +// LinkResolutionResult is the result of a link address resolution attempt. +type LinkResolutionResult struct { + LinkAddress tcpip.LinkAddress + Success bool } -// GetLinkAddress finds the link address corresponding to a neighbor's address. -// -// Returns a link address for the remote address, if readily available. +// GetLinkAddress finds the link address corresponding to a network address. // // Returns ErrNotSupported if the stack is not configured with a link address // resolver for the specified network protocol. @@ -1525,53 +1546,56 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t // with a notification channel for the caller to block on. Triggers address // resolution asynchronously. // -// If onResolve is provided, it will be called either immediately, if -// resolution is not required, or when address resolution is complete, with -// the resolved link address and whether resolution succeeded. After any -// callbacks have been called, the returned notification channel is closed. +// onResolve will be called either immediately, if resolution is not required, +// or when address resolution is complete, with the resolved link address and +// whether resolution succeeded. // // If specified, the local address must be an address local to the interface // the neighbor cache belongs to. The local address is the source address of // a packet prompting NUD/link address resolution. -// -// TODO(gvisor.dev/issue/5151): Don't return the link address. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return "", nil, tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } linkRes, ok := s.linkAddrResolvers[protocol] if !ok { - return "", nil, tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} + } + + if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil } - return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + return err } // Neighbors returns all IP to MAC address associations. -func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { +func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return nil, tcpip.ErrUnknownNICID + return nil, &tcpip.ErrUnknownNICID{} } return nic.neighbors() } // AddStaticNeighbor statically associates an IP address to a MAC address. -func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.addStaticNeighbor(addr, linkAddr) @@ -1580,26 +1604,26 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd // RemoveNeighbor removes an IP to MAC address association previously created // either automically or by AddStaticNeighbor. Returns ErrBadAddress if there // is no association with the provided address. -func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error { +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.removeNeighbor(addr) } // ClearNeighbors removes all IP to MAC address associations. -func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { +func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.clearNeighbors() @@ -1609,25 +1633,25 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CheckRegisterTransportEndpoint checks if an endpoint can be registered with // the stack transport dispatcher. -func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { +func (s *Stack) UnregisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. -func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { +func (s *Stack) StartTransportEndpointCleanup(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.cleanupEndpointsMu.Lock() s.cleanupEndpoints[ep] = struct{}{} s.cleanupEndpointsMu.Unlock() @@ -1652,13 +1676,13 @@ func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, tran // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. -func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error { return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. -func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { +func (s *Stack) UnregisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { s.demux.unregisterRawEndpoint(netProto, transProto, ep) } @@ -1762,7 +1786,7 @@ func (s *Stack) Resume() { // RegisterPacketEndpoint registers ep with the stack, causing it to receive // all traffic of the specified netProto on the given NIC. If nicID is 0, it // receives traffic from every NIC. -func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { +func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -1781,7 +1805,7 @@ func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.Network // Capture on a specific device. nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } if err := nic.registerPacketEndpoint(netProto, ep); err != nil { return err @@ -1819,12 +1843,12 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip // WritePacketToRemote writes a payload on the specified NIC using the provided // network protocol and remote link address. -func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } pkt := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: int(nic.MaxHeaderLength()), @@ -1889,37 +1913,37 @@ func (s *Stack) RemoveTCPProbe() { } // JoinGroup joins the given multicast group on the given NIC. -func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { +func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.joinGroup(protocol, multicastAddr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // LeaveGroup leaves the given multicast group on the given NIC. -func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { +func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.leaveGroup(protocol, multicastAddr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // IsInGroup returns true if the NIC with ID nicID has joined the multicast // group multicastAddr. -func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.isInGroup(multicastAddr), nil } - return false, tcpip.ErrUnknownNICID + return false, &tcpip.ErrUnknownNICID{} } // IPTables returns the stack's iptables. @@ -1959,26 +1983,26 @@ func (s *Stack) AllowICMPMessage() bool { // GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol // number installed on the specified NIC. -func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) { +func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() nic, ok := s.nics[nicID] if !ok { - return nil, tcpip.ErrUnknownNICID + return nil, &tcpip.ErrUnknownNICID{} } return nic.getNetworkEndpoint(proto), nil } // NUDConfigurations gets the per-interface NUD configurations. -func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { +func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() if !ok { - return NUDConfigurations{}, tcpip.ErrUnknownNICID + return NUDConfigurations{}, &tcpip.ErrUnknownNICID{} } return nic.nudConfigs() @@ -1988,13 +2012,13 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.setNUDConfigs(c) @@ -2036,7 +2060,7 @@ func generateRandInt64() int64 { } // FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -2048,7 +2072,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres addressEndpoint.DecRef() return nic.getNetworkEndpoint(netProto), nil } - return nil, tcpip.ErrBadAddress + return nil, &tcpip.ErrBadAddress{} } // FindNICNameFromID returns the name of the NIC for the given NICID. diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 0b093e6c5..8d9b20b7e 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -14,7 +14,9 @@ package stack -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // MinBufferSize is the smallest size of a receive or send buffer. @@ -29,14 +31,6 @@ const ( DefaultMaxBufferSize = 4 << 20 // 4 MiB ) -// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to -// get/set the default, min and max send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - // ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to // get/set the default, min and max receive buffer sizes. type ReceiveBufferSizeOption struct { @@ -46,17 +40,17 @@ type ReceiveBufferSizeOption struct { } // SetOption allows setting stack wide options. -func (s *Stack) SetOption(option interface{}) *tcpip.Error { +func (s *Stack) SetOption(option interface{}) tcpip.Error { switch v := option.(type) { - case SendBufferSizeOption: + case tcpip.SendBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } if v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } s.mu.Lock() @@ -68,11 +62,11 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } if v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } s.mu.Lock() @@ -81,14 +75,14 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option allows retrieving stack wide options. -func (s *Stack) Option(option interface{}) *tcpip.Error { +func (s *Stack) Option(option interface{}) tcpip.Error { switch v := option.(type) { - case *SendBufferSizeOption: + case *tcpip.SendBufferSizeOption: s.mu.RLock() *v = s.sendBufferSize s.mu.RUnlock() @@ -101,6 +95,6 @@ func (s *Stack) Option(option interface{}) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7e935ddff..41f95811f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,6 +60,15 @@ const ( protocolNumberOffset = 2 ) +func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { + if addr, ok := s.GetMainNICAddress(nicID, proto); !ok { + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto) + } else if addr != want { + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want) + } + return nil +} + // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and // received packets; the counts of all endpoints are aggregated in the protocol // descriptor. @@ -81,7 +90,7 @@ type fakeNetworkEndpoint struct { dispatcher stack.TransportDispatcher } -func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { +func (f *fakeNetworkEndpoint) Enable() tcpip.Error { f.mu.Lock() defer f.mu.Unlock() f.mu.enabled = true @@ -145,7 +154,7 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { +func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { return 0 } @@ -153,7 +162,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -176,18 +185,30 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func (f *fakeNetworkEndpoint) Close() { f.AddressableEndpointState.Cleanup() } +// Stats implements NetworkEndpoint. +func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats { + return &fakeNetworkEndpointStats{} +} + +var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil) + +type fakeNetworkEndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} + // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. @@ -202,15 +223,15 @@ type fakeNetworkProtocol struct { } } -func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { +func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fakeNetNumber } -func (f *fakeNetworkProtocol) MinimumPacketSize() int { +func (*fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } -func (f *fakeNetworkProtocol) DefaultPrefixLen() int { +func (*fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } @@ -232,23 +253,23 @@ func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.Li return e } -func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: f.defaultTTL = uint8(*v) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(f.defaultTTL) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -397,7 +418,7 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return err @@ -406,7 +427,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r *stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -435,14 +456,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -579,8 +600,8 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{}) } } @@ -628,8 +649,9 @@ func TestDisableUnknownNIC(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + err := s.DisableNIC(1) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) } } @@ -687,8 +709,9 @@ func TestRemoveUnknownNIC(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + err := s.RemoveNIC(1) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) } } @@ -731,8 +754,8 @@ func TestRemoveNIC(t *testing.T) { func TestRouteWithDownNIC(t *testing.T) { tests := []struct { name string - downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error - upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error }{ { name: "Disabled NIC", @@ -890,15 +913,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -911,7 +934,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1036,11 +1059,12 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) + err := s.RemoveAddress(1, localAddr) + if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) } } @@ -1087,12 +1111,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) + { + err := s.RemoveAddress(1, localAddr) + if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) + } } } @@ -1186,7 +1213,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1214,7 +1241,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1253,8 +1280,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1290,7 +1317,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1333,8 +1360,8 @@ func TestPromiscuousMode(t *testing.T) { // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{}) } // Set promiscuous mode to false, then check that packet can't be @@ -1540,7 +1567,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1590,8 +1617,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + { + _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) + } } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} @@ -1610,8 +1640,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + { + _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) + } } } @@ -1753,9 +1786,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { anyAddr = header.IPv6Any } - want := tcpip.ErrNetworkUnreachable + var want tcpip.Error = &tcpip.ErrNetworkUnreachable{} if tc.routeNeeded { - want = tcpip.ErrNoRoute + want = &tcpip.ErrNoRoute{} } // If there is no endpoint, it won't work. @@ -1769,8 +1802,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { // Route table is empty but we need a route, this should cause an error. - if err != tcpip.ErrNoRoute { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{}) } } else { if err != nil { @@ -1861,20 +1894,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) + gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if len(primaryAddrAdded) == 0 { // No primary addresses present. if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr) } } else { // At least one primary address was added, verify the returned // address is in the list of primary addresses we added. if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded) } } }) @@ -1915,12 +1948,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { - t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { @@ -1928,12 +1957,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get no address after removal. - gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } }) } @@ -2102,7 +2127,7 @@ func TestCreateNICWithOptions(t *testing.T) { type callArgsAndExpect struct { nicID tcpip.NICID opts stack.NICOptions - err *tcpip.Error + err tcpip.Error } tests := []struct { @@ -2120,7 +2145,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(1), opts: stack.NICOptions{Name: "eth2"}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2135,7 +2160,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(2), opts: stack.NICOptions{Name: "lo"}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2165,7 +2190,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(1), opts: stack.NICOptions{}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2474,12 +2499,12 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } } - gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + // Check that we get no address after removal. + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } - if gotMainAddr != expectedMainAddr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) + if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { + t.Fatal(err) } }) } @@ -2525,12 +2550,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } }) } @@ -2561,12 +2582,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // Address should not be considered bound to the // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } linkLocalAddr := header.LinkLocalAddr(linkAddr1) @@ -2584,12 +2601,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { + t.Fatal(err) } } @@ -2621,17 +2634,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { t.Fatal("AddAddressWithOptions failed:", err) } - addr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + addr, ok := s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if pi == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) } } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) } { @@ -2710,18 +2723,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { if err := s.RemoveAddress(1, "\x03"); err != nil { t.Fatalf("RemoveAddress failed: %v", err) } - addr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatalf("s.GetMainNICAddress failed: %v", err) + addr, ok = s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) - + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) } } else { if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) } } }) @@ -3247,12 +3259,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Address should be tentative so it should not be a main address. - got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Enabling the NIC should start DAD for the address. @@ -3264,12 +3272,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Address should not be considered bound to the NIC yet (DAD ongoing). - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Wait for DAD to resolve. @@ -3284,12 +3288,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { + t.Fatal(err) } // Enabling the NIC again should be a no-op. @@ -3299,12 +3299,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { + t.Fatal(err) } } @@ -3313,14 +3309,14 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { testCases := []struct { name string rs stack.ReceiveBufferSizeOption - err *tcpip.Error + err tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, @@ -3352,31 +3348,32 @@ func TestStackSendBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - ss stack.SendBufferSizeOption - err *tcpip.Error + ss tcpip.SendBufferSizeOption + err tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations - {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{}) defer s.Close() - if err := s.SetOption(tc.ss); err != tc.err { - t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + err := s.SetOption(tc.ss) + if diff := cmp.Diff(tc.err, err); diff != "" { + t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff) } - var ss stack.SendBufferSizeOption if tc.err == nil { + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err != nil { t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) } @@ -3778,20 +3775,16 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } else if gotAddr != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } // Should still get the address when the NIC is diabled. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("DisableNIC(%d): %s", nicID, err) } - if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } else if gotAddr != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } } @@ -3939,7 +3932,7 @@ func TestFindRouteWithForwarding(t *testing.T) { addrNIC tcpip.NICID localAddr tcpip.Address - findRouteErr *tcpip.Error + findRouteErr tcpip.Error dependentOnForwarding bool }{ { @@ -3948,7 +3941,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID1, localAddr: fakeNetCfg.nic2Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3957,7 +3950,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: true, addrNIC: nicID1, localAddr: fakeNetCfg.nic2Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3966,7 +3959,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID1, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4002,7 +3995,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID2, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4011,7 +4004,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: true, addrNIC: nicID2, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4035,7 +4028,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4051,7 +4044,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, addrNIC: nicID1, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4059,7 +4052,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, addrNIC: nicID1, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4067,7 +4060,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4075,7 +4068,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4107,7 +4100,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4115,7 +4108,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4123,7 +4116,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4131,7 +4124,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4186,8 +4179,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if r != nil { defer r.Release() } - if err != test.findRouteErr { - t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + if diff := cmp.Diff(test.findRouteErr, err); diff != "" { + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { @@ -4234,8 +4227,11 @@ func TestFindRouteWithForwarding(t *testing.T) { if err := s.SetForwarding(test.netCfg.proto, false); err != nil { t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) } - if err := send(r, data); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + { + err := send(r, data) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{}) + } } if n := ep1.Drain(); n != 0 { t.Errorf("got %d unexpected packets from ep1", n) @@ -4297,8 +4293,9 @@ func TestWritePacketToRemote(t *testing.T) { } t.Run("InvalidNICID", func(t *testing.T) { - if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()) + if _, ok := err.(*tcpip.ErrUnknownDevice); !ok { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{}) } pkt, ok := e.Read() if got, want := ok, false; got != want { @@ -4372,10 +4369,64 @@ func TestGetLinkAddressErrors(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID) + { + err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{}) + } + } + { + err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{}) + } + } +} + +func TestStaticGetLinkAddress(t *testing.T) { + const ( + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + }) + if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4", + proto: ipv4.ProtocolNumber, + addr: header.IPv4Broadcast, + expectedLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "IPv6", + proto: ipv6.ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, } - if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch := make(chan stack.LinkResolutionResult, 1) + if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) { + ch <- r + }); err != nil { + t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err) + } + + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 07b2818d2..26eceb804 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -205,7 +205,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() @@ -222,7 +222,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -294,7 +294,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { for i, n := range netProtos { if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) @@ -306,7 +306,7 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // checkEndpoint checks if an endpoint can be registered with the dispatcher. -func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { for _, n := range netProtos { if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { return err @@ -403,7 +403,7 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() @@ -412,7 +412,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } } @@ -422,7 +422,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p return nil } -func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error { ep.mu.RLock() defer ep.mu.RUnlock() @@ -431,7 +431,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } } @@ -456,7 +456,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { if id.RemotePort != 0 { // SO_REUSEPORT only applies to bound/listening endpoints. flags.LoadBalanced = false @@ -464,7 +464,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } eps.mu.Lock() @@ -482,7 +482,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } -func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { if id.RemotePort != 0 { // SO_REUSEPORT only applies to bound/listening endpoints. flags.LoadBalanced = false @@ -490,7 +490,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } eps.mu.RLock() @@ -649,10 +649,10 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw // endpoint. -func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } eps.mu.Lock() diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 57e1f8354..10cbbe589 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -175,9 +175,9 @@ func TestTransportDemuxerRegister(t *testing.T) { for _, test := range []struct { name string proto tcpip.NetworkProtocolNumber - want *tcpip.Error + want tcpip.Error }{ - {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, {"success", ipv4.ProtocolNumber, nil}, } { t.Run(test.name, func(t *testing.T) { @@ -194,7 +194,7 @@ func TestTransportDemuxerRegister(t *testing.T) { if !ok { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) @@ -294,7 +294,7 @@ func TestBindToDeviceDistribution(t *testing.T) { defer wq.EventUnregister(&we) defer close(ch) - var err *tcpip.Error + var err tcpip.Error ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 9d39533a1..cf5de747b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "bytes" "io" "testing" @@ -67,9 +68,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { return &f.ops } -func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} - ep.ops.InitHandler(ep) +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} + ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) return ep } @@ -86,19 +87,20 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return mask } -func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { return tcpip.ReadResult{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, Data: buffer.View(v).ToVectorisedView(), @@ -112,42 +114,42 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions } // SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return -1, tcpip.ErrUnknownProtocolOption +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { + return -1, &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*fakeTransportEndpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error { f.peerAddr = addr.Addr // Find the route. r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { r.Release() return err @@ -162,22 +164,22 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 { return f.uniqueID } -func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { +func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error { return nil } -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { +func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return nil } func (*fakeTransportEndpoint) Reset() { } -func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { +func (*fakeTransportEndpoint) Listen(int) tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -186,9 +188,8 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai return a, nil, nil } -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { +func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error { if err := f.proto.stack.RegisterTransportEndpoint( - a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, @@ -202,11 +203,11 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { return nil } -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{}, nil } -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{}, nil } @@ -232,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress, route: route, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) f.acceptQueue = append(f.acceptQueue, ep) } @@ -251,7 +252,7 @@ func (*fakeTransportEndpoint) Resume(*stack.Stack) {} func (*fakeTransportEndpoint) Wait() {} -func (*fakeTransportEndpoint) LastError() *tcpip.Error { +func (*fakeTransportEndpoint) LastError() tcpip.Error { return nil } @@ -279,19 +280,19 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return newFakeTransportEndpoint(f, netProto, f.stack), nil } -func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return nil, tcpip.ErrUnknownProtocol +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return nil, &tcpip.ErrUnknownProtocol{} } func (*fakeTransportProtocol) MinimumPacketSize() int { return fakeTransHeaderLen } -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) { return 0, 0, nil } @@ -299,23 +300,23 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndp return stack.UnknownDestinationPacketHandled } -func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPModerateReceiveBufferOption: f.opts.good = bool(*v) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPModerateReceiveBufferOption: *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -520,8 +521,10 @@ func TestTransportSend(t *testing.T) { } // Create buffer that will hold the payload. - view := buffer.NewView(30) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + b := make([]byte, 30) + var r bytes.Reader + r.Reset(b) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("write failed: %v", err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 56aac093c..c500a0d1c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -29,6 +29,7 @@ package tcpip import ( + "bytes" "errors" "fmt" "io" @@ -46,141 +47,6 @@ import ( // Using header.IPv4AddressSize would cause an import cycle. const ipv4AddressSize = 4 -// Error represents an error in the netstack error space. Using a special type -// ensures that errors outside of this space are not accidentally introduced. -// -// All errors must have unique msg strings. -// -// +stateify savable -type Error struct { - msg string - - ignoreStats bool -} - -// String implements fmt.Stringer.String. -func (e *Error) String() string { - if e == nil { - return "<nil>" - } - return e.msg -} - -// IgnoreStats indicates whether this error type should be included in failure -// counts in tcpip.Stats structs. -func (e *Error) IgnoreStats() bool { - return e.ignoreStats -} - -// Errors that can be returned by the network stack. -var ( - ErrUnknownProtocol = &Error{msg: "unknown protocol"} - ErrUnknownNICID = &Error{msg: "unknown nic id"} - ErrUnknownDevice = &Error{msg: "unknown device"} - ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} - ErrDuplicateNICID = &Error{msg: "duplicate nic id"} - ErrDuplicateAddress = &Error{msg: "duplicate address"} - ErrNoRoute = &Error{msg: "no route"} - ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} - ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} - ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} - ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} - ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} - ErrNoPortAvailable = &Error{msg: "no ports are available"} - ErrPortInUse = &Error{msg: "port is in use"} - ErrBadLocalAddress = &Error{msg: "bad local address"} - ErrClosedForSend = &Error{msg: "endpoint is closed for send"} - ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} - ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} - ErrConnectionRefused = &Error{msg: "connection was refused"} - ErrTimeout = &Error{msg: "operation timed out"} - ErrAborted = &Error{msg: "operation aborted"} - ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} - ErrDestinationRequired = &Error{msg: "destination address is required"} - ErrNotSupported = &Error{msg: "operation not supported"} - ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} - ErrNotConnected = &Error{msg: "endpoint not connected"} - ErrConnectionReset = &Error{msg: "connection reset by peer"} - ErrConnectionAborted = &Error{msg: "connection aborted"} - ErrNoSuchFile = &Error{msg: "no such file"} - ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} - ErrBadAddress = &Error{msg: "bad address"} - ErrNetworkUnreachable = &Error{msg: "network is unreachable"} - ErrMessageTooLong = &Error{msg: "message too long"} - ErrNoBufferSpace = &Error{msg: "no buffer space available"} - ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} - ErrNotPermitted = &Error{msg: "operation not permitted"} - ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} - ErrMalformedHeader = &Error{msg: "header is malformed"} - ErrBadBuffer = &Error{msg: "bad buffer"} -) - -var messageToError map[string]*Error - -var populate sync.Once - -// StringToError converts an error message to the error. -func StringToError(s string) *Error { - populate.Do(func() { - var errors = []*Error{ - ErrUnknownProtocol, - ErrUnknownNICID, - ErrUnknownDevice, - ErrUnknownProtocolOption, - ErrDuplicateNICID, - ErrDuplicateAddress, - ErrNoRoute, - ErrBadLinkEndpoint, - ErrAlreadyBound, - ErrInvalidEndpointState, - ErrAlreadyConnecting, - ErrAlreadyConnected, - ErrNoPortAvailable, - ErrPortInUse, - ErrBadLocalAddress, - ErrClosedForSend, - ErrClosedForReceive, - ErrWouldBlock, - ErrConnectionRefused, - ErrTimeout, - ErrAborted, - ErrConnectStarted, - ErrDestinationRequired, - ErrNotSupported, - ErrQueueSizeNotSupported, - ErrNotConnected, - ErrConnectionReset, - ErrConnectionAborted, - ErrNoSuchFile, - ErrInvalidOptionValue, - ErrBadAddress, - ErrNetworkUnreachable, - ErrMessageTooLong, - ErrNoBufferSpace, - ErrBroadcastDisabled, - ErrNotPermitted, - ErrAddressFamilyNotSupported, - ErrMalformedHeader, - ErrBadBuffer, - } - - messageToError = make(map[string]*Error) - for _, e := range errors { - if messageToError[e.String()] != nil { - panic("tcpip errors with duplicated message: " + e.String()) - } - messageToError[e.String()] = e - } - }) - - e, ok := messageToError[s] - if !ok { - panic("unknown error message: " + s) - } - - return e -} - // Errors related to Subnet var ( errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") @@ -194,7 +60,7 @@ type ErrSaveRejection struct { } // Error returns a sensible description of the save rejection error. -func (e ErrSaveRejection) Error() string { +func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } @@ -471,30 +337,15 @@ type FullAddress struct { // This interface allows the endpoint to request the amount of data it needs // based on internal buffers without exposing them. type Payloader interface { - // FullPayload returns all available bytes. - FullPayload() ([]byte, *Error) + io.Reader - // Payload returns a slice containing at most size bytes. - Payload(size int) ([]byte, *Error) + // Len returns the number of bytes of the unread portion of the + // Reader. + Len() int } -// SlicePayload implements Payloader for slices. -// -// This is typically used for tests. -type SlicePayload []byte - -// FullPayload implements Payloader.FullPayload. -func (s SlicePayload) FullPayload() ([]byte, *Error) { - return s, nil -} - -// Payload implements Payloader.Payload. -func (s SlicePayload) Payload(size int) ([]byte, *Error) { - if size > len(s) { - size = len(s) - } - return s[:size], nil -} +var _ Payloader = (*bytes.Buffer)(nil) +var _ Payloader = (*bytes.Reader)(nil) var _ io.Writer = (*SliceWriter)(nil) @@ -647,7 +498,7 @@ type Endpoint interface { // If non-zero number of bytes are successfully read and written to dst, err // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer // should be returned. - Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error) + Read(io.Writer, ReadOptions) (ReadResult, Error) // Write writes data to the endpoint's peer. This method does not block if // the data cannot be written. @@ -662,7 +513,7 @@ type Endpoint interface { // stream (TCP) Endpoints may return partial writes, and even then only // in the case where writing additional data would block. Other Endpoints // will either write the entire message or return an error. - Write(Payloader, WriteOptions) (int64, *Error) + Write(Payloader, WriteOptions) (int64, Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -676,21 +527,21 @@ type Endpoint interface { // connected returns nil. Calling connect again results in ErrAlreadyConnected. // Anything else -- the attempt to connect failed. // - // If address.Addr is empty, this means that Enpoint has to be + // If address.Addr is empty, this means that Endpoint has to be // disconnected if this is supported, otherwise // ErrAddressFamilyNotSupported must be returned. - Connect(address FullAddress) *Error + Connect(address FullAddress) Error // Disconnect disconnects the endpoint from its peer. - Disconnect() *Error + Disconnect() Error // Shutdown closes the read and/or write end of the endpoint connection // to its peer. - Shutdown(flags ShutdownFlags) *Error + Shutdown(flags ShutdownFlags) Error // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. - Listen(backlog int) *Error + Listen(backlog int) Error // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. This method does not @@ -700,36 +551,36 @@ type Endpoint interface { // // If peerAddr is not nil then it is populated with the peer address of the // returned endpoint. - Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error) + Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. - Bind(address FullAddress) *Error + Bind(address FullAddress) Error // GetLocalAddress returns the address to which the endpoint is bound. - GetLocalAddress() (FullAddress, *Error) + GetLocalAddress() (FullAddress, Error) // GetRemoteAddress returns the address to which the endpoint is // connected. - GetRemoteAddress() (FullAddress, *Error) + GetRemoteAddress() (FullAddress, Error) // Readiness returns the current readiness of the endpoint. For example, // if waiter.EventIn is set, the endpoint is immediately readable. Readiness(mask waiter.EventMask) waiter.EventMask // SetSockOpt sets a socket option. - SetSockOpt(opt SettableSocketOption) *Error + SetSockOpt(opt SettableSocketOption) Error // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. - SetSockOptInt(opt SockOptInt, v int) *Error + SetSockOptInt(opt SockOptInt, v int) Error // GetSockOpt gets a socket option. - GetSockOpt(opt GettableSocketOption) *Error + GetSockOpt(opt GettableSocketOption) Error // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. - GetSockOptInt(SockOptInt) (int, *Error) + GetSockOptInt(SockOptInt) (int, Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -752,7 +603,7 @@ type Endpoint interface { SetOwner(owner PacketOwner) // LastError clears and returns the last error reported by the endpoint. - LastError() *Error + LastError() Error // SocketOptions returns the structure which contains all the socket // level options. @@ -840,10 +691,6 @@ const ( // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption - // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to - // specify the send buffer size option. - SendBufferSizeOption - // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the receive buffer size option. ReceiveBufferSizeOption @@ -1011,12 +858,54 @@ type SettableSocketOption interface { isSettableSocketOption() } +// CongestionControlState indicates the current congestion control state for +// TCP sender. +type CongestionControlState int + +const ( + // Open indicates that the sender is receiving acks in order and + // no loss or dupACK's etc have been detected. + Open CongestionControlState = iota + // RTORecovery indicates that an RTO has occurred and the sender + // has entered an RTO based recovery phase. + RTORecovery + // FastRecovery indicates that the sender has entered FastRecovery + // based on receiving nDupAck's. This state is entered only when + // SACK is not in use. + FastRecovery + // SACKRecovery indicates that the sender has entered SACK based + // recovery. + SACKRecovery + // Disorder indicates the sender either received some SACK blocks + // or dupACK's. + Disorder +) + // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. type TCPInfoOption struct { - RTT time.Duration + // RTT is the smoothed round trip time. + RTT time.Duration + + // RTTVar is the round trip time variation. RTTVar time.Duration + + // RTO is the retransmission timeout for the endpoint. + RTO time.Duration + + // CcState is the congestion control state. + CcState CongestionControlState + + // SndCwnd is the congestion window, in packets. + SndCwnd uint32 + + // SndSsthresh is the threshold between slow start and congestion + // avoidance. + SndSsthresh uint32 + + // ReorderSeen indicates if reordering is seen in the endpoint. + ReorderSeen bool } func (*TCPInfoOption) isGettableSocketOption() {} @@ -1248,6 +1137,31 @@ type IPPacketInfo struct { DestinationAddr Address } +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + // Min is the minimum size for send buffer. + Min int + + // Default is the default size for send buffer. + Default int + + // Max is the maximum size for send buffer. + Max int +} + +// GetSendBufferLimits is used to get the send buffer size limits. +type GetSendBufferLimits func(StackHandler) SendBufferSizeOption + +// GetStackSendBufferLimits is used to get default, min and max send buffer size. +func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption { + var ss SendBufferSizeOption + if err := so.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + return ss +} + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. @@ -1317,8 +1231,33 @@ func (s *StatCounter) String() string { return strconv.FormatUint(s.Value(), 10) } +// A MultiCounterStat keeps track of two counters at once. +type MultiCounterStat struct { + a, b *StatCounter +} + +// Init sets both internal counters to point to a and b. +func (m *MultiCounterStat) Init(a, b *StatCounter) { + m.a = a + m.b = b +} + +// Increment adds one to the counters. +func (m *MultiCounterStat) Increment() { + m.a.Increment() + m.b.Increment() +} + +// IncrementBy increments the counters by v. +func (m *MultiCounterStat) IncrementBy(v uint64) { + m.a.IncrementBy(v) + m.b.IncrementBy(v) +} + // ICMPv4PacketStats enumerates counts for all ICMPv4 packet types. type ICMPv4PacketStats struct { + // LINT.IfChange(ICMPv4PacketStats) + // Echo is the total number of ICMPv4 echo packets counted. Echo *StatCounter @@ -1358,10 +1297,56 @@ type ICMPv4PacketStats struct { // InfoReply is the total number of ICMPv4 information reply packets // counted. InfoReply *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4PacketStats) +} + +// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats. +type ICMPv4SentPacketStats struct { + // LINT.IfChange(ICMPv4SentPacketStats) + + ICMPv4PacketStats + + // Dropped is the total number of ICMPv4 packets dropped due to link + // layer errors. + Dropped *StatCounter + + // RateLimited is the total number of ICMPv4 packets dropped due to + // rate limit being exceeded. + RateLimited *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4SentPacketStats) +} + +// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats. +type ICMPv4ReceivedPacketStats struct { + // LINT.IfChange(ICMPv4ReceivedPacketStats) + + ICMPv4PacketStats + + // Invalid is the total number of invalid ICMPv4 packets received. + Invalid *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4ReceivedPacketStats) +} + +// ICMPv4Stats collects ICMPv4-specific stats. +type ICMPv4Stats struct { + // LINT.IfChange(ICMPv4Stats) + + // PacketsSent contains statistics about sent packets. + PacketsSent ICMPv4SentPacketStats + + // PacketsReceived contains statistics about received packets. + PacketsReceived ICMPv4ReceivedPacketStats + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4Stats) } // ICMPv6PacketStats enumerates counts for all ICMPv6 packet types. type ICMPv6PacketStats struct { + // LINT.IfChange(ICMPv6PacketStats) + // EchoRequest is the total number of ICMPv6 echo request packets // counted. EchoRequest *StatCounter @@ -1416,32 +1401,14 @@ type ICMPv6PacketStats struct { // MulticastListenerDone is the total number of Multicast Listener Done // messages counted. MulticastListenerDone *StatCounter -} - -// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats. -type ICMPv4SentPacketStats struct { - ICMPv4PacketStats - - // Dropped is the total number of ICMPv4 packets dropped due to link - // layer errors. - Dropped *StatCounter - // RateLimited is the total number of ICMPv6 packets dropped due to - // rate limit being exceeded. - RateLimited *StatCounter -} - -// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats. -type ICMPv4ReceivedPacketStats struct { - ICMPv4PacketStats - - // Invalid is the total number of ICMPv4 packets received that the - // transport layer could not parse. - Invalid *StatCounter + // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6PacketStats) } // ICMPv6SentPacketStats collects outbound ICMPv6-specific stats. type ICMPv6SentPacketStats struct { + // LINT.IfChange(ICMPv6SentPacketStats) + ICMPv6PacketStats // Dropped is the total number of ICMPv6 packets dropped due to link @@ -1451,47 +1418,41 @@ type ICMPv6SentPacketStats struct { // RateLimited is the total number of ICMPv6 packets dropped due to // rate limit being exceeded. RateLimited *StatCounter + + // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6SentPacketStats) } // ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats. type ICMPv6ReceivedPacketStats struct { + // LINT.IfChange(ICMPv6ReceivedPacketStats) + ICMPv6PacketStats // Unrecognized is the total number of ICMPv6 packets received that the // transport layer does not know how to parse. Unrecognized *StatCounter - // Invalid is the total number of ICMPv6 packets received that the - // transport layer could not parse. + // Invalid is the total number of invalid ICMPv6 packets received. Invalid *StatCounter // RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets // dropped due to being router-specific packets. RouterOnlyPacketsDroppedByHost *StatCounter -} - -// ICMPv4Stats collects ICMPv4-specific stats. -type ICMPv4Stats struct { - // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type - // and a single count of packets which failed to write to the link - // layer. - PacketsSent ICMPv4SentPacketStats - // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4 - // packet type and a single count of invalid packets received. - PacketsReceived ICMPv4ReceivedPacketStats + // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6ReceivedPacketStats) } // ICMPv6Stats collects ICMPv6-specific stats. type ICMPv6Stats struct { - // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type - // and a single count of packets which failed to write to the link - // layer. + // LINT.IfChange(ICMPv6Stats) + + // PacketsSent contains statistics about sent packets. PacketsSent ICMPv6SentPacketStats - // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6 - // packet type and a single count of invalid packets received. + // PacketsReceived contains statistics about received packets. PacketsReceived ICMPv6ReceivedPacketStats + + // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6Stats) } // ICMPStats collects ICMP-specific stats (both v4 and v6). @@ -1505,6 +1466,8 @@ type ICMPStats struct { // IGMPPacketStats enumerates counts for all IGMP packet types. type IGMPPacketStats struct { + // LINT.IfChange(IGMPPacketStats) + // MembershipQuery is the total number of Membership Query messages counted. MembershipQuery *StatCounter @@ -1518,22 +1481,29 @@ type IGMPPacketStats struct { // LeaveGroup is the total number of Leave Group messages counted. LeaveGroup *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPPacketStats) } // IGMPSentPacketStats collects outbound IGMP-specific stats. type IGMPSentPacketStats struct { + // LINT.IfChange(IGMPSentPacketStats) + IGMPPacketStats // Dropped is the total number of IGMP packets dropped. Dropped *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPSentPacketStats) } // IGMPReceivedPacketStats collects inbound IGMP-specific stats. type IGMPReceivedPacketStats struct { + // LINT.IfChange(IGMPReceivedPacketStats) + IGMPPacketStats - // Invalid is the total number of IGMP packets received that IGMP could not - // parse. + // Invalid is the total number of invalid IGMP packets received. Invalid *StatCounter // ChecksumErrors is the total number of IGMP packets dropped due to bad @@ -1543,21 +1513,27 @@ type IGMPReceivedPacketStats struct { // Unrecognized is the total number of unrecognized messages counted, these // are silently ignored for forward-compatibilty. Unrecognized *StatCounter + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPReceivedPacketStats) } -// IGMPStats colelcts IGMP-specific stats. +// IGMPStats collects IGMP-specific stats. type IGMPStats struct { - // IGMPSentPacketStats contains counts of sent packets by IGMP packet type - // and a single count of invalid packets received. + // LINT.IfChange(IGMPStats) + + // PacketsSent contains statistics about sent packets. PacketsSent IGMPSentPacketStats - // IGMPReceivedPacketStats contains counts of received packets by IGMP packet - // type and a single count of invalid packets received. + // PacketsReceived contains statistics about received packets. PacketsReceived IGMPReceivedPacketStats + + // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats) } // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { + // LINT.IfChange(IPStats) + // PacketsReceived is the total number of IP packets received from the // link layer. PacketsReceived *StatCounter @@ -1575,7 +1551,7 @@ type IPStats struct { InvalidSourceAddressesReceived *StatCounter // PacketsDelivered is the total number of incoming IP packets that - // are successfully delivered to the transport layer via HandlePacket. + // are successfully delivered to the transport layer. PacketsDelivered *StatCounter // PacketsSent is the total number of IP packets sent via WritePacket. @@ -1613,10 +1589,14 @@ type IPStats struct { // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived *StatCounter + + // LINT.ThenChange(network/ip/stats.go:MultiCounterIPStats) } // ARPStats collects ARP-specific stats. type ARPStats struct { + // LINT.IfChange(ARPStats) + // PacketsReceived is the number of ARP packets received from the link layer. PacketsReceived *StatCounter @@ -1644,10 +1624,6 @@ type ARPStats struct { // ARP request with a bad local address. OutgoingRequestBadLocalAddressErrors *StatCounter - // OutgoingRequestNetworkUnreachableErrors is the number of failures to send - // an ARP request with a network unreachable error. - OutgoingRequestNetworkUnreachableErrors *StatCounter - // OutgoingRequestsDropped is the number of ARP requests which failed to write // to a link-layer endpoint. OutgoingRequestsDropped *StatCounter @@ -1666,6 +1642,8 @@ type ARPStats struct { // OutgoingRepliesSent is the number of ARP replies successfully written to a // link-layer endpoint. OutgoingRepliesSent *StatCounter + + // LINT.ThenChange(network/arp/stats.go:multiCounterARPStats) } // TCPStats collects TCP-specific stats. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 1742a178d..71695b630 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -7,6 +7,7 @@ go_test( size = "small", srcs = [ "forward_test.go", + "iptables_test.go", "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", @@ -16,6 +17,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index ac9670f9a..38e1881c7 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -38,96 +38,207 @@ import ( var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil) var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil) -// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner -// link endpoint and checks the destination link address before delivering -// network packets to the network dispatcher. -// -// See ethernet.Endpoint for more details. -func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { - var e endpointWithDestinationCheck - e.Endpoint.Init(ethernet.New(ep), &e) - return &e -} - -// endpointWithDestinationCheck is a link endpoint that checks the destination -// link address before delivering network packets to the network dispatcher. -type endpointWithDestinationCheck struct { - nested.Endpoint -} - -// DeliverNetworkPacket implements stack.NetworkDispatcher. -func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) - } -} - -func TestForwarding(t *testing.T) { - const ( - host1NICID = 1 - routerNICID1 = 2 - routerNICID2 = 3 - host2NICID = 4 - - listenPort = 8080 - ) +const ( + host1NICID = 1 + routerNICID1 = 2 + routerNICID2 = 3 + host2NICID = 4 +) - host1IPv4Addr := tcpip.ProtocolAddress{ +var ( + host1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 24, }, } - routerNIC1IPv4Addr := tcpip.ProtocolAddress{ + routerNIC1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - routerNIC2IPv4Addr := tcpip.ProtocolAddress{ + routerNIC2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, }, } - host2IPv4Addr := tcpip.ProtocolAddress{ + host2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr := tcpip.ProtocolAddress{ + host1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), PrefixLen: 64, }, } - routerNIC1IPv6Addr := tcpip.ProtocolAddress{ + routerNIC1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - routerNIC2IPv6Addr := tcpip.ProtocolAddress{ + routerNIC2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr := tcpip.ProtocolAddress{ + host2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::2").To16()), PrefixLen: 64, }, } +) + +func setupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) { + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + { + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + { + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) +} + +// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner +// link endpoint and checks the destination link address before delivering +// network packets to the network dispatcher. +// +// See ethernet.Endpoint for more details. +func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { + var e endpointWithDestinationCheck + e.Endpoint.Init(ethernet.New(ep), &e) + return &e +} + +// endpointWithDestinationCheck is a link endpoint that checks the destination +// link address before delivering network packets to the network dispatcher. +type endpointWithDestinationCheck struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) + } +} + +func TestForwarding(t *testing.T) { + const listenPort = 8080 type endpointAndAddresses struct { serverEP tcpip.Endpoint @@ -229,7 +340,7 @@ func TestForwarding(t *testing.T) { subTests := []struct { name string proto tcpip.TransportProtocolNumber - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) needRemoteAddr bool }{ @@ -250,7 +361,7 @@ func TestForwarding(t *testing.T) { { name: "TCP", proto: tcp.ProtocolNumber, - expectedConnectErr: tcpip.ErrConnectStarted, + expectedConnectErr: &tcpip.ErrConnectStarted{}, setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() @@ -260,7 +371,7 @@ func TestForwarding(t *testing.T) { var addr tcpip.FullAddress for { newEP, wq, err := ep.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-ch continue } @@ -294,113 +405,7 @@ func TestForwarding(t *testing.T) { host1Stack := stack.New(stackOpts) routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - - host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) - routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) - } - if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - { - Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - }) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) defer epsAndAddrs.serverEP.Close() @@ -415,8 +420,11 @@ func TestForwarding(t *testing.T) { t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) } - if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr { - t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr) + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } } if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) @@ -436,9 +444,10 @@ func TestForwarding(t *testing.T) { write := func(ep tcpip.Endpoint, data []byte) { t.Helper() - dataPayload := tcpip.SlicePayload(data) + var r bytes.Reader + r.Reset(data) var wOpts tcpip.WriteOptions - n, err := ep.Write(dataPayload, wOpts) + n, err := ep.Write(&r, wOpts) if err != nil { t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) } @@ -486,7 +495,7 @@ func TestForwarding(t *testing.T) { read(serverCH, serverEP, data, clientAddr) - data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + data = []byte{5, 6, 7, 8, 9, 10, 11, 12} write(serverEP, data) read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) }) diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go new file mode 100644 index 000000000..21a8dd291 --- /dev/null +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -0,0 +1,336 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type inputIfNameMatcher struct { + name string +} + +var _ stack.Matcher = (*inputIfNameMatcher)(nil) + +func (*inputIfNameMatcher) Name() string { + return "inputIfNameMatcher" +} + +func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { + return (hook == stack.Input && im.name != "" && im.name == inNicName), false +} + +const ( + nicID = 1 + nicName = "nic1" + anotherNicName = "nic2" + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + srcAddrV4 = "\x0a\x00\x00\x01" + dstAddrV4 = "\x0a\x00\x00\x02" + srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + payloadSize = 20 +) + +func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + }) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err) + } + return s, e +} + +func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + e := channel.New(0, header.IPv4MinimumMTU, linkAddr) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err) + } + return s, e +} + +func genPacketV6() *stack.PacketBuffer { + pktSize := header.IPv6MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv6(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadSize, + TransportProtocol: 99, + HopLimit: 255, + SrcAddr: srcAddrV6, + DstAddr: dstAddrV6, + }) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func genPacketV4() *stack.PacketBuffer { + pktSize := header.IPv4MinimumSize + payloadSize + hdr := buffer.NewPrependable(pktSize) + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&header.IPv4Fields{ + TOS: 0, + TotalLength: uint16(pktSize), + ID: 1, + Flags: 0, + FragmentOffset: 16, + TTL: 48, + Protocol: 99, + SrcAddr: srcAddrV4, + DstAddr: dstAddrV4, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + vv := hdr.View().ToVectorisedView() + return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) +} + +func TestIPTablesStatsForInput(t *testing.T) { + tests := []struct { + name string + setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) + setupFilter func(*testing.T, *stack.Stack) + genPacket func() *stack.PacketBuffer + proto tcpip.NetworkProtocolNumber + expectReceived int + expectInputDropped int + }{ + { + name: "IPv6 Accept", + setupStack: genStackV6, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept", + setupStack: genStackV4, + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface matches)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface matches)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv6 Drop (input interface does not match but invert is true)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv4 Drop (input interface does not match but invert is true)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ + InputInterface: anotherNicName, + InputInterfaceInvert: true, + } + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 1, + }, + { + name: "IPv6 Accept (input interface does not match using a matcher)", + setupStack: genStackV6, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + } + }, + genPacket: genPacketV6, + proto: header.IPv6ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + { + name: "IPv4 Accept (input interface does not match using a matcher)", + setupStack: genStackV4, + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Input] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + } + }, + genPacket: genPacketV4, + proto: header.IPv4ProtocolNumber, + expectReceived: 1, + expectInputDropped: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := test.setupStack(t) + test.setupFilter(t, s) + e.InjectInbound(test.proto, test.genPacket()) + + if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { + t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) + } + if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { + t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index af32d3009..7069352f2 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -19,11 +19,14 @@ import ( "fmt" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -32,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -207,8 +211,10 @@ func TestPing(t *testing.T) { defer ep.Close() icmpBuf := test.icmpBuf(t) + var r bytes.Reader + r.Reset(icmpBuf) wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(_, _): %s", err) } else if want := int64(len(icmpBuf)); n != want { t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) @@ -251,7 +257,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name string netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error sockError tcpip.SockError }{ { @@ -270,9 +276,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name: "IPv4 without resolvable remote", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, + Err: &tcpip.ErrNoRoute{}, ErrType: byte(header.ICMPv4DstUnreachable), ErrCode: byte(header.ICMPv4HostUnreachable), ErrOrigin: tcpip.SockExtErrorOriginICMP, @@ -292,9 +298,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name: "IPv6 without resolvable remote", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, + Err: &tcpip.ErrNoRoute{}, ErrType: byte(header.ICMPv6DstUnreachable), ErrCode: byte(header.ICMPv6AddressUnreachable), ErrOrigin: tcpip.SockExtErrorOriginICMP6, @@ -351,16 +357,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr := listenerAddr remoteAddr.Addr = test.remoteAddr - if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted) + { + err := clientEP.Connect(remoteAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{}) + } } // Wait for an error due to link resolution failing, or the endpoint to be // writable. <-ch - var wOpts tcpip.WriteOptions - if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr { - t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + _, err := clientEP.Write(&r, wOpts) + if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" { + t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff) + } } if test.expectedWriteErr == nil { @@ -374,7 +388,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { sockErrCmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(tcpip.SockError{}), - cmp.Comparer(func(a, b *tcpip.Error) bool { + cmp.Comparer(func(a, b tcpip.Error) bool { // tcpip.Error holds an unexported field but the errors netstack uses // are pre defined so we can simply compare pointers. return a == b @@ -404,20 +418,134 @@ func TestGetLinkAddress(t *testing.T) { ) tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedLinkAddr bool + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedOk bool }{ { - name: "IPv4", + name: "IPv4 resolvable", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedOk: true, }, { - name: "IPv6", + name: "IPv6 resolvable", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedOk: true, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + expectedOk: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + expectedOk: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, useNeighborCache := range []bool{true, false} { + t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + UseNeighborCache: useNeighborCache, + } + + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + wantRes := stack.LinkResolutionResult{Success: test.expectedOk} + if test.expectedOk { + wantRes.LinkAddress = linkAddr2 + } + if diff := cmp.Diff(wantRes, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) + } + }) + } +} + +func TestRouteResolvedFields(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + localAddr tcpip.Address + remoteAddr tcpip.Address + immediatelyResolvable bool + expectedSuccess bool + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4 immediately resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: header.IPv4AllSystems, + immediatelyResolvable: true, + expectedSuccess: true, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems), + }, + { + name: "IPv6 immediately resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: header.IPv6AllNodesMulticastAddress, + immediatelyResolvable: true, + expectedSuccess: true, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, + { + name: "IPv4 resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: true, + expectedLinkAddr: linkAddr2, + }, + { + name: "IPv6 resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: true, + expectedLinkAddr: linkAddr2, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + localAddr: ipv4Addr1.AddressWithPrefix.Address, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + localAddr: ipv6Addr1.AddressWithPrefix.Address, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + immediatelyResolvable: false, + expectedSuccess: false, }, } @@ -431,28 +559,618 @@ func TestGetLinkAddress(t *testing.T) { } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + var wantRouteInfo stack.RouteInfo + wantRouteInfo.LocalLinkAddress = linkAddr1 + wantRouteInfo.LocalAddress = test.localAddr + wantRouteInfo.RemoteAddress = test.remoteAddr + wantRouteInfo.NetProto = test.netProto + wantRouteInfo.Loop = stack.PacketOut + wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr + + ch := make(chan stack.ResolvedFieldsResult, 1) + + if !test.immediatelyResolvable { + wantUnresolvedRouteInfo := wantRouteInfo + wantUnresolvedRouteInfo.RemoteLinkAddress = "" - for i := 0; i < 2; i++ { - addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {}) - var want *tcpip.Error - if i == 0 { - want = tcpip.ErrWouldBlock + err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } - if err != want { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want) + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) } - if i == 0 { - <-ch - continue + if !test.expectedSuccess { + return } - if addr != linkAddr2 { - t.Fatalf("got addr = %s, want = %s", addr, linkAddr2) + // At this point the neighbor table should be populated so the route + // should be immediately resolvable. + } + + if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }); err != nil { + t.Errorf("r.ResolvedFields(_): %s", err) + } + select { + case routeResolveRes := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected route to be immediately resolvable") } }) } }) } } + +func TestWritePacketsLinkResolution(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedWriteErr tcpip.Error + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + } + + host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err) + } + defer serverEP.Close() + + serverAddr := tcpip.FullAddress{Port: 1234} + if err := serverEP.Bind(serverAddr); err != nil { + t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err) + } + + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() + + data := []byte{1, 2} + var pkts stack.PacketBufferList + for _, d := range data { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: buffer.View([]byte{d}).ToVectorisedView(), + }) + pkt.TransportProtocolNumber = udp.ProtocolNumber + length := uint16(pkt.Size()) + udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: serverAddr.Port, + Length: length, + }) + xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + + pkts.PushBack(pkt) + } + + params := stack.NetworkHeaderParams{ + Protocol: udp.ProtocolNumber, + TTL: 64, + TOS: stack.DefaultTOS, + } + + if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil { + t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err) + } else if want := pkts.Len(); want != n { + t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want) + } + + var writer bytes.Buffer + count := 0 + for { + var rOpts tcpip.ReadOptions + res, err := serverEP.Read(&writer, rOpts) + if err != nil { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + // Should not have anymore bytes to read after we read the sent + // number of bytes. + if count == len(data) { + break + } + + <-serverCH + continue + } + + t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err) + } + count += res.Count + } + + if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want { + t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if diff := cmp.Diff(data, writer.Bytes()); diff != "" { + t.Errorf("read bytes mismatch (-want +got):\n%s", diff) + } + }) + } +} + +type eventType int + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved +) + +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + entry stack.NeighborEntry +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) +} + +var _ stack.NUDDispatcher = (*nudDispatcher)(nil) + +type nudDispatcher struct { + c chan eventInfo +} + +func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryChanged, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) waitForEvent(want eventInfo) error { + if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil +} + +// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it +// that the neighbor used for a route is reachable. +func TestTCPConfirmNeighborReachability(t *testing.T) { + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + neighborAddr tcpip.Address + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + isHost1Listener bool + }{ + { + name: "IPv4 active connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 active connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 passive connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv4 passive connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + nudDisp := nudDispatcher{ + c: make(chan eventInfo, 3), + } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, + UseNeighborCache: true, + } + host1StackOpts := stackOpts + host1StackOpts.NUDDisp = &nudDisp + + host1Stack := stack.New(host1StackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + // Add a reachable dynamic entry to our neighbor table for the remote. + { + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: linkAddr2, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryAdded, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, + }); err != nil { + t.Fatalf("error waiting for initial NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the remote's neighbor entry to be stale before creating a + // TCP connection from host1 to some remote. + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err) + } + // The maximum reachable time for a neighbor is some maximum random factor + // applied to the base reachable time. + // + // See NUDConfigurations.BaseReachableTime for more information. + maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + + listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) + defer listenerEP.Close() + defer clientEP.Close() + listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} + if err := listenerEP.Bind(listenerAddr); err != nil { + t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) + } + if err := listenerEP.Listen(1); err != nil { + t.Fatalf("listenerEP.Listen(1): %s", err) + } + { + err := clientEP.Connect(listenerAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{}) + } + } + + // Wait for the TCP handshake to complete then make sure the neighbor is + // reachable without entering the probe state as TCP should provide NUD + // with confirmation that the neighbor is reachable (indicated by a + // successful 3-way handshake). + <-clientCH + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the neighbor to be stale again then send data to the remote. + // + // On successful transmission, the neighbor should become reachable + // without probing the neighbor as a TCP ACK would be received which is an + // indication of the neighbor being reachable. + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if test.isHost1Listener { + // If host1 is not the client, host1 does not send any data so TCP + // has no way to know it is making forward progress. Because of this, + // TCP should not mark the route reachable and NUD should go through the + // probe state. + clock.Advance(nudConfigs.DelayFirstProbeTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for probe NUD event: %s", err) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3b13ba04d..ab67762ef 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -37,7 +37,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) type ndpDispatcher struct{} -func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { } func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { @@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { Port: localPort, }, } - n, err := sep.Write(tcpip.SlicePayload(data), wopts) + var r bytes.Reader + r.Reset(data) + n, err := sep.Write(&r, wopts) if err != nil { t.Fatalf("sep.Write(_, _): %s", err) } @@ -260,8 +262,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -320,11 +322,14 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) } - if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + { + err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) + } } } @@ -468,13 +473,17 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { Addr: test.dstAddr, Port: localPort, } - if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + { + err := connectingEndpoint.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + } } if !test.expectAccept { - if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + _, _, err := listeningEndpoint.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } return } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index ce7c16bd1..d685fdd36 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -479,8 +479,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) - if n, err := wep.ep.Write(data, writeOpts); err != nil { + data := []byte{byte(i), 2, 3, 4} + var r bytes.Reader + r.Reset(data) + if n, err := wep.ep.Write(&r, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) @@ -759,8 +761,11 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { if err := ep.SetSockOpt(&removeOpt); err != nil { t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) } - if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) + { + _, err := ep.Read(&buf, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{}) + } } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index b222d2b05..9654c9527 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -81,7 +81,7 @@ func TestLocalPing(t *testing.T) { linkEndpoint func() stack.LinkEndpoint localAddr tcpip.Address icmpBuf func(*testing.T) buffer.View - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) }{ { @@ -126,7 +126,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -135,7 +135,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -144,7 +144,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, { @@ -153,7 +153,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, } @@ -186,17 +186,22 @@ func TestLocalPing(t *testing.T) { defer ep.Close() connAddr := tcpip.FullAddress{Addr: test.localAddr} - if err := ep.Connect(connAddr); err != test.expectedConnectErr { - t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + { + err := ep.Connect(connAddr) + if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) + } } if test.expectedConnectErr != nil { return } - payload := tcpip.SlicePayload(test.icmpBuf(t)) + payload := test.icmpBuf(t) + var r bytes.Reader + r.Reset(payload) var wOpts tcpip.WriteOptions - if n, err := ep.Write(payload, wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) } else if n != int64(len(payload)) { t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) @@ -261,12 +266,12 @@ func TestLocalUDP(t *testing.T) { subTests := []struct { name string addAddress bool - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error }{ { name: "Unassigned local address", addAddress: false, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, }, { name: "Assigned local address", @@ -329,12 +334,14 @@ func TestLocalUDP(t *testing.T) { Port: 80, } - clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + clientPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(clientPayload) wOpts := tcpip.WriteOptions{ To: &serverAddr, } - if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr { t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) } else if subTest.expectedWriteErr != nil { // Nothing else to test if we expected not to be able to send the @@ -376,12 +383,14 @@ func TestLocalUDP(t *testing.T) { } } - serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + serverPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(serverPayload) wOpts := tcpip.WriteOptions{ To: &clientAddr, } - if n, err := server.Write(serverPayload, wOpts); err != nil { + if n, err := server.Write(&r, wOpts); err != nil { t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) } else if n != int64(len(serverPayload)) { t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 256e19296..3cf05520d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -69,8 +69,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int + mu sync.RWMutex `state:"nosave"` // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState @@ -85,7 +84,7 @@ type endpoint struct { ops tcpip.SocketOptions } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -94,11 +93,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.SetSendBufferSize(32*1024, false /* notify */) + + // Override with stack defaults. + var ss tcpip.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) + } return ep, nil } @@ -119,7 +124,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -154,14 +159,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -188,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil @@ -199,7 +204,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.state { case stateInitial: case stateConnected: @@ -207,11 +212,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi case stateBound: if to == nil { - return false, tcpip.ErrDestinationRequired + return false, &tcpip.ErrDestinationRequired{} } return false, nil default: - return false, tcpip.ErrInvalidEndpointState + return false, &tcpip.ErrInvalidEndpointState{} } e.mu.RUnlock() @@ -236,18 +241,18 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -257,10 +262,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } to := opts.To @@ -270,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } // Prepare for write. @@ -292,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc nicID := to.NIC if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } nicID = e.BindNICID @@ -313,11 +318,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc route = r } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } + var err tcpip.Error switch e.NetProto { case header.IPv4ProtocolNumber: err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) @@ -334,12 +340,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.TTLOption: e.mu.Lock() @@ -351,7 +357,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -362,11 +368,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSize - e.mu.Unlock() - return v, nil case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() @@ -381,18 +382,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -410,7 +411,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi // Linux performs these basic checks. if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } icmpv4.SetChecksum(0) @@ -424,9 +425,9 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -441,7 +442,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } dataVV := data.ToVectorisedView() @@ -456,7 +457,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { return tcpip.FullAddress{}, 0, err @@ -465,12 +466,12 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect connects the endpoint to its peer. Specifying a NIC is optional. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -485,12 +486,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } if nicID != 0 && nicID != e.BindNICID { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nicID = e.BindNICID default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -535,19 +536,19 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection // to its peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() e.shutdownFlags |= flags if e.state != stateConnected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } if flags&tcpip.ShutdownRead != 0 { @@ -565,31 +566,31 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen is not supported by UDP, it just fails. -func (*endpoint) Listen(int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) - switch err { + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) + switch err.(type) { case nil: return true, nil - case tcpip.ErrPortInUse: + case *tcpip.ErrPortInUse: return false, nil default: return false, err @@ -599,11 +600,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ return id, err } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. if e.state != stateInitial { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -619,7 +620,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { if len(addr.Addr) != 0 { // A local address was specified, verify that it's valid. if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } @@ -647,7 +648,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -663,7 +664,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -675,12 +676,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() if e.state != stateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return tcpip.FullAddress{ @@ -805,7 +806,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { func (*endpoint) Wait() {} // LastError implements tcpip.Endpoint.LastError. -func (*endpoint) LastError() *tcpip.Error { +func (*endpoint) LastError() tcpip.Error { return nil } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 9d263c0ec..c9fa9974a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -69,12 +69,13 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) if e.state != stateBound && e.state != stateConnected { return } - var err *tcpip.Error + var err tcpip.Error if e.state == stateConnected { e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) if err != nil { @@ -84,7 +85,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.ID.LocalAddress = e.route.LocalAddress } else if len(e.ID.LocalAddress) != 0 { // stateBound if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3820e5dc7..47f7dd1cb 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -59,18 +59,18 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if netProto != p.netProto() { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if netProto != p.netProto() { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } @@ -87,7 +87,7 @@ func (p *protocol) MinimumPacketSize() int { } // ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil. -func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { switch p.number { case ProtocolNumber4: hdr := header.ICMPv4(v) @@ -106,13 +106,13 @@ func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stac } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c0d6fb442..73bb66830 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -79,24 +79,22 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool - boundNIC tcpip.NICID + mu sync.RWMutex `state:"nosave"` + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID // lastErrorMu protects lastError. - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // ops is used to get socket level options. ops tcpip.SocketOptions } // NewEndpoint returns a new packet endpoint. -func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -106,14 +104,13 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb netProto: netProto, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - ep.sndBufSizeMax = ss.Default + ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -162,16 +159,16 @@ func (ep *endpoint) Close() { func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. if ep.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if ep.rcvClosed { ep.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } ep.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -201,49 +198,49 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul n, err := packet.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be // disconnected, and this function always returns tpcip.ErrNotSupported. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be -// connected, and this function always returnes tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return tcpip.ErrNotSupported +// connected, and this function always returnes *tcpip.ErrNotSupported. +func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used -// with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - return tcpip.ErrNotSupported +// with Shutdown, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with -// Listen, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrNotSupported +// Listen, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Listen(backlog int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with -// Accept, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +// Accept, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } // Bind implements tcpip.Endpoint.Bind. -func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { // TODO(gvisor.dev/issue/173): Add Bind support. // "By default, all packets of the specified protocol type are passed @@ -277,14 +274,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, tcpip.ErrNotSupported +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { // Even a connected socket doesn't return a remote address. - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness implements tcpip.Endpoint.Readiness. @@ -306,38 +303,20 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be // used with SetSockOpt, and this function always returns -// tcpip.ErrNotSupported. -func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +// *tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := ep.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - ep.mu.Lock() - ep.sndBufSizeMax = v - ep.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -357,11 +336,11 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (ep *endpoint) LastError() *tcpip.Error { +func (ep *endpoint) LastError() tcpip.Error { ep.lastErrorMu.Lock() defer ep.lastErrorMu.Unlock() @@ -371,19 +350,19 @@ func (ep *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (ep *endpoint) UpdateLastError(err *tcpip.Error) { +func (ep *endpoint) UpdateLastError(err tcpip.Error) { ep.lastErrorMu.Lock() ep.lastError = err ep.lastErrorMu.Unlock() } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrNotSupported +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrNotSupported{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -395,12 +374,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - ep.mu.Lock() - v := ep.sndBufSizeMax - ep.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: ep.rcvMu.Lock() v := ep.rcvBufSizeMax @@ -408,7 +381,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index e2fa96d17..ece662c0d 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -63,29 +63,11 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. ep.stack = stack.StackFromEnv + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(*err) + panic(err) } } - -// saveLastError is invoked by stateify. -func (ep *endpoint) saveLastError() string { - if ep.lastError == nil { - return "" - } - - return ep.lastError.String() -} - -// loadLastError is invoked by stateify. -func (ep *endpoint) loadLastError(s string) { - if s == "" { - return - } - - ep.lastError = tcpip.StringToError(s) -} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ae743f75e..9c9ccc0ff 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -76,12 +76,10 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int - closed bool - connected bool - bound bool + mu sync.RWMutex `state:"nosave"` + closed bool + connected bool + bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. route *stack.Route `state:"manual"` @@ -95,13 +93,13 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. -func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } e := &endpoint{ @@ -112,16 +110,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, associated: associated, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) e.ops.SetHeaderIncluded(!associated) + e.ops.SetSendBufferSize(32*1024, false /* notify */) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -138,7 +136,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } - if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { return nil, err } @@ -159,7 +157,7 @@ func (e *endpoint) Close() { return } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.rcvMu.Lock() defer e.rcvMu.Unlock() @@ -191,16 +189,16 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -227,37 +225,37 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil } // Write implements tcpip.Endpoint.Write. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // We can create, but not write to, unassociated IPv6 endpoints. if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } } n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -267,22 +265,22 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } e.mu.RLock() defer e.mu.RUnlock() if e.closed { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } - payloadBytes, err := p.FullPayload() - if err != nil { - return 0, err + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, &tcpip.ErrBadBuffer{} } // If this is an unassociated socket and callee provided a nonzero @@ -290,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } dstAddr := ip.DestinationAddress() // Update dstAddr with the address in the IP header, unless @@ -311,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - return 0, tcpip.ErrDestinationRequired + return 0, &tcpip.ErrDestinationRequired{} } return e.finishWrite(payloadBytes, e.route) @@ -321,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } // Find the route to the destination. If BindAddress is 0, @@ -338,7 +336,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) { +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, tcpip.Error) { if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), @@ -365,22 +363,22 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect implements tcpip.Endpoint.Connect. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { - return tcpip.ErrAddressFamilyNotSupported + return &tcpip.ErrAddressFamilyNotSupported{} } e.mu.Lock() defer e.mu.Unlock() if e.closed { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nic := addr.NIC @@ -395,7 +393,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } else if addr.NIC != e.BindNICID { // We're bound and addr specifies a NIC. They must be // the same. - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } } @@ -407,15 +405,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { route.Release() return err } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.RegisterNICID = nic } - // Save the route we've connected via. + if e.route != nil { + // If the endpoint was previously connected then release any previous route. + e.route.Release() + } e.route = route e.connected = true @@ -423,42 +424,42 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() if !e.connected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } return nil } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(backlog int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept implements tcpip.Endpoint.Accept. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } // Bind implements tcpip.Endpoint.Bind. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If a local address was specified, verify that it's valid. - if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { - return tcpip.ErrBadLocalAddress + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 { + return &tcpip.ErrBadLocalAddress{} } if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { return err } - e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) e.RegisterNICID = addr.NIC e.BindNICID = addr.NIC } @@ -470,14 +471,14 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, tcpip.ErrNotSupported +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { // Even a connected socket doesn't return a remote address. - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness implements tcpip.Endpoint.Readiness. @@ -498,37 +499,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) - } - if v > ss.Max { - v = ss.Max - } - if v < ss.Min { - v = ss.Min - } - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil - case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -548,17 +531,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -570,12 +553,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax @@ -583,7 +560,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } @@ -703,7 +680,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { func (*endpoint) Wait() {} // LastError implements tcpip.Endpoint.LastError. -func (*endpoint) LastError() *tcpip.Error { +func (*endpoint) LastError() tcpip.Error { return nil } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 4a7e1c039..263ec5146 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -69,10 +69,11 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) // If the endpoint is connected, re-connect. if e.connected { - var err *tcpip.Error + var err tcpip.Error // TODO(gvisor.dev/issue/4906): Properly restore the route with the right // remote address. We used to pass e.remote.RemoteAddress which was // effectively the empty address but since moving e.route to hold a pointer @@ -88,12 +89,12 @@ func (e *endpoint) Resume(s *stack.Stack) { // If the endpoint is bound, re-bind. if e.bound { if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } if e.associated { - if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index f30aa2a4a..e393b993d 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -25,11 +25,11 @@ import ( type EndpointFactory struct{} // NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. -func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } // NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. -func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7e81203ba..fcdd032c5 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -99,7 +99,6 @@ go_test( "//pkg/rand", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6921de0f1..842c1622b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,7 +199,7 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // On success, a handshake h is returned with h.ep.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { +func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -267,7 +267,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.mu.Unlock() ep.Close() - return nil, tcpip.ErrConnectionAborted + return nil, &tcpip.ErrConnectionAborted{} } l.addPendingEndpoint(ep) @@ -281,14 +281,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q l.removePendingEndpoint(ep) - return nil, tcpip.ErrConnectionAborted + return nil, &tcpip.ErrConnectionAborted{} } deferAccept = l.listenEP.deferAccept } // Register new endpoint so that packets are routed to it. - if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { ep.mu.Unlock() ep.Close() @@ -313,7 +313,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // established endpoint is returned with e.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { +func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) { h, err := l.startHandshake(s, opts, queue, owner) if err != nil { return nil, err @@ -467,7 +467,7 @@ func (e *endpoint) notifyAborted() { // cookies to accept connections. // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error { defer s.decRef() h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) @@ -522,7 +522,7 @@ func (e *endpoint) acceptQueueIsFull() bool { // and needs to handle it. // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() @@ -692,7 +692,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er } // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { n.mu.Unlock() n.Close() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index f711cd4df..4695b66d6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -226,7 +226,7 @@ func (h *handshake) checkAck(s *segment) bool { // synSentState handles a segment received when the TCP 3-way handshake is in // the SYN-SENT state. -func (h *handshake) synSentState(s *segment) *tcpip.Error { +func (h *handshake) synSentState(s *segment) tcpip.Error { // RFC 793, page 37, states that in the SYN-SENT state, a reset is // acceptable if the ack field acknowledges the SYN. if s.flagIsSet(header.TCPFlagRst) { @@ -237,7 +237,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.ep.workerCleanup = true // Although the RFC above calls out ECONNRESET, Linux actually returns // ECONNREFUSED here so we do as well. - return tcpip.ErrConnectionRefused + return &tcpip.ErrConnectionRefused{} } return nil } @@ -314,12 +314,12 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // synRcvdState handles a segment received when the TCP 3-way handshake is in // the SYN-RCVD state. -func (h *handshake) synRcvdState(s *segment) *tcpip.Error { +func (h *handshake) synRcvdState(s *segment) tcpip.Error { if s.flagIsSet(header.TCPFlagRst) { // RFC 793, page 37, states that in the SYN-RCVD state, a reset // is acceptable if the sequence number is in the window. if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { - return tcpip.ErrConnectionRefused + return &tcpip.ErrConnectionRefused{} } return nil } @@ -349,7 +349,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) if !h.active { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } h.resetState() @@ -412,7 +412,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } -func (h *handshake) handleSegment(s *segment) *tcpip.Error { +func (h *handshake) handleSegment(s *segment) tcpip.Error { h.sndWnd = s.window if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) @@ -429,7 +429,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error { // processSegments goes through the segment queue and processes up to // maxSegmentsPerWake (if they're available). -func (h *handshake) processSegments() *tcpip.Error { +func (h *handshake) processSegments() tcpip.Error { for i := 0; i < maxSegmentsPerWake; i++ { s := h.ep.segmentQueue.dequeue() if s == nil { @@ -505,7 +505,7 @@ func (h *handshake) start() { } // complete completes the TCP 3-way handshake initiated by h.start(). -func (h *handshake) complete() *tcpip.Error { +func (h *handshake) complete() tcpip.Error { // Set up the wakers. var s sleep.Sleeper resendWaker := sleep.Waker{} @@ -555,7 +555,7 @@ func (h *handshake) complete() *tcpip.Error { case wakerForNotification: n := h.ep.fetchNotifications() if (n¬ifyClose)|(n¬ifyAbort) != 0 { - return tcpip.ErrAborted + return &tcpip.ErrAborted{} } if n¬ifyDrain != 0 { for !h.ep.segmentQueue.empty() { @@ -593,19 +593,19 @@ type backoffTimer struct { t *time.Timer } -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { if timeout > maxTimeout { - return nil, tcpip.ErrTimeout + return nil, &tcpip.ErrTimeout{} } bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} bt.t = time.AfterFunc(timeout, f) return bt, nil } -func (bt *backoffTimer) reset() *tcpip.Error { +func (bt *backoffTimer) reset() tcpip.Error { bt.timeout *= 2 if bt.timeout > MaxRTO { - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } bt.t.Reset(bt.timeout) return nil @@ -706,7 +706,7 @@ type tcpFields struct { txHash uint32 } -func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error { tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { @@ -716,7 +716,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp return nil } -func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error { tf.txHash = e.txHash if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() @@ -755,7 +755,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { // We need to shallow clone the VectorisedView here as ReadToView will // split the VectorisedView and Trim underlying views as it splits. Not // doing the clone here will cause the underlying views of data itself @@ -803,7 +803,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > math.MaxUint16 { tf.rcvWnd = math.MaxUint16 @@ -875,7 +875,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendRaw sends a TCP segment to the endpoint's peer. -func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] @@ -895,55 +895,60 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn return err } -func (e *endpoint) handleWrite() *tcpip.Error { - // Move packets from send queue to send list. The queue is accessible - // from other goroutines and protected by the send mutex, while the send - // list is only accessible from the handler goroutine, so it needs no - // mutexes. +func (e *endpoint) handleWrite() { e.sndBufMu.Lock() + next := e.drainSendQueueLocked() + e.sndBufMu.Unlock() + + e.sendData(next) +} +// Move packets from send queue to send list. +// +// Precondition: e.sndBufMu must be locked. +func (e *endpoint) drainSendQueueLocked() *segment { first := e.sndQueue.Front() if first != nil { e.snd.writeList.PushBackList(&e.sndQueue) e.sndBufInQueue = 0 } + return first +} - e.sndBufMu.Unlock() - +// Precondition: e.mu must be locked. +func (e *endpoint) sendData(next *segment) { // Initialize the next segment to write if it's currently nil. if e.snd.writeNext == nil { - e.snd.writeNext = first + e.snd.writeNext = next } // Push out any new packets. e.snd.sendData() - - return nil } -func (e *endpoint) handleClose() *tcpip.Error { +func (e *endpoint) handleClose() { if !e.EndpointState().connected() { - return nil + return } // Drain the send queue. e.handleWrite() // Mark send side as closed. e.snd.closed = true - - return nil } // resetConnectionLocked puts the endpoint in an error state with the given // error code and sends a RST if and only if the error is not ErrConnectionReset // indicating that the connection is being reset due to receiving a RST. This // method must only be called from the protocol goroutine. -func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { +func (e *endpoint) resetConnectionLocked(err tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.setEndpointState(StateError) e.hardError = err - if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { + switch err.(type) { + case *tcpip.ErrConnectionReset, *tcpip.ErrTimeout: + default: // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk // which can cause sndNxt to be outside the acceptable window on the @@ -1053,7 +1058,7 @@ func (e *endpoint) drainClosingSegmentQueue() { } } -func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { +func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states // except SYN-SENT, all reset (RST) segments are @@ -1081,7 +1086,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // delete the TCB, and return. case StateCloseWait: e.transitionToStateCloseLocked() - e.hardError = tcpip.ErrAborted + e.hardError = &tcpip.ErrAborted{} e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: @@ -1094,14 +1099,14 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // handleSegment is invoked from the processor goroutine // rather than the worker goroutine. e.notifyProtocolGoroutine(notifyResetByPeer) - return false, tcpip.ErrConnectionReset + return false, &tcpip.ErrConnectionReset{} } } return true, nil } // handleSegments processes all inbound segments. -func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { +func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { if e.EndpointState().closed() { @@ -1148,7 +1153,7 @@ func (e *endpoint) probeSegment() { // handleSegment handles a given segment and notifies the worker goroutine if // if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { +func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { // Invoke the tcp probe if installed. The tcp probe function will update // the TCPEndpointState after the segment is processed. defer e.probeSegment() @@ -1222,7 +1227,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. -func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { +func (e *endpoint) keepaliveTimerExpired() tcpip.Error { userTimeout := e.userTimeout e.keepalive.Lock() @@ -1236,13 +1241,13 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with @@ -1286,7 +1291,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { e.mu.Lock() var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1332,6 +1337,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } + // Reaching this point means that we successfully completed the 3-way + // handshake with our peer. + // + // Completing the 3-way handshake is an indication that the route is valid + // and the remote is reachable as the only way we can complete a handshake + // is if our SYN reached the remote and their ACK reached us. + e.route.ConfirmReachable() + drained := e.drainDone != nil if drained { close(e.drainDone) @@ -1344,19 +1357,25 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // wakes up. funcs := []struct { w *sleep.Waker - f func() *tcpip.Error + f func() tcpip.Error }{ { w: &e.sndWaker, - f: e.handleWrite, + f: func() tcpip.Error { + e.handleWrite() + return nil + }, }, { w: &e.sndCloseWaker, - f: e.handleClose, + f: func() tcpip.Error { + e.handleClose() + return nil + }, }, { w: &closeWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { // This means the socket is being closed due // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. @@ -1367,10 +1386,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.snd.resendWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { if !e.snd.retransmitTimerExpired() { e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } return nil }, @@ -1381,7 +1400,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.newSegmentWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { return e.handleSegments(false /* fastPath */) }, }, @@ -1391,7 +1410,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.notificationWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { n := e.fetchNotifications() if n¬ifyNonZeroReceiveWindow != 0 { e.rcv.nonZeroWindow() @@ -1408,11 +1427,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n¬ifyReset != 0 || n¬ifyAbort != 0 { - return tcpip.ErrConnectionAborted + return &tcpip.ErrConnectionAborted{} } if n¬ifyResetByPeer != 0 { - return tcpip.ErrConnectionReset + return &tcpip.ErrConnectionReset{} } if n¬ifyClose != 0 && closeTimer == nil { @@ -1491,7 +1510,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. - cleanupOnError := func(err *tcpip.Error) { + cleanupOnError := func(err tcpip.Error) { e.stack.Stats().TCP.CurrentConnected.Decrement() e.workerCleanup = true if err != nil { diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 1d1b01a6c..2d90246e4 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -15,11 +15,11 @@ package tcp_test import ( + "strings" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -37,7 +37,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { // Start connection attempt, it must fail. err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrNoRoute { + if _, ok := err.(*tcpip.ErrNoRoute); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } } @@ -49,7 +49,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -156,7 +156,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -391,7 +391,7 @@ func testV4Accept(t *testing.T, c *context.Context) { defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -415,8 +415,10 @@ func testV4Accept(t *testing.T, c *context.Context) { t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) } + var r strings.Reader data := "Don't panic" - nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + r.Reset(data) + nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() tcp = header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { @@ -523,7 +525,7 @@ func TestV6AcceptOnV6(t *testing.T) { defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress _, _, err := c.EP.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -547,7 +549,7 @@ func TestV4AcceptOnV4(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) @@ -611,7 +613,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -633,7 +635,7 @@ func TestV4ListenCloseOnV4(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ea509ac73..6e4e26c39 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -386,12 +386,12 @@ type endpoint struct { // hardError is meaningful only when state is stateError. It stores the // error to be returned when read/write syscalls are called and the // endpoint is in this state. hardError is protected by endpoint mu. - hardError *tcpip.Error `state:".(string)"` + hardError tcpip.Error // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // rcvReadMu synchronizes calls to Read. // @@ -557,7 +557,6 @@ type endpoint struct { // When the send side is closed, the protocol goroutine is notified via // sndCloseWaker, and sndClosed is set to true. sndBufMu sync.Mutex `state:"nosave"` - sndBufSize int sndBufUsed int sndClosed bool sndBufInQueue seqnum.Size @@ -869,7 +868,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, - sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), keepalive: keepalive{ // Linux defaults. @@ -882,13 +880,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) + e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - e.sndBufSize = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -967,7 +966,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() - if e.sndClosed || e.sndBufUsed < e.sndBufSize { + sndBufSize := e.getSendBufferSize() + if e.sndClosed || e.sndBufUsed < sndBufSize { result |= waiter.EventOut } e.sndBufMu.Unlock() @@ -1059,7 +1059,7 @@ func (e *endpoint) Close() { if isResetState { // Close the endpoint without doing full shutdown and // send a RST. - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) e.closeNoShutdownLocked() // Wake up worker to close the endpoint. @@ -1087,7 +1087,7 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1161,7 +1161,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1293,14 +1293,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Preconditions: e.mu must be held to call this function. -func (e *endpoint) hardErrorLocked() *tcpip.Error { +func (e *endpoint) hardErrorLocked() tcpip.Error { err := e.hardError e.hardError = nil return err } // Preconditions: e.mu must be held to call this function. -func (e *endpoint) lastErrorLocked() *tcpip.Error { +func (e *endpoint) lastErrorLocked() tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1309,7 +1309,7 @@ func (e *endpoint) lastErrorLocked() *tcpip.Error { } // LastError implements tcpip.Endpoint.LastError. -func (e *endpoint) LastError() *tcpip.Error { +func (e *endpoint) LastError() tcpip.Error { e.LockUser() defer e.UnlockUser() if err := e.hardErrorLocked(); err != nil { @@ -1319,7 +1319,7 @@ func (e *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (e *endpoint) UpdateLastError(err *tcpip.Error) { +func (e *endpoint) UpdateLastError(err tcpip.Error) { e.LockUser() e.lastErrorMu.Lock() e.lastError = err @@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvReadMu.Lock() defer e.rcvReadMu.Unlock() @@ -1337,7 +1337,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // can remove segments from the list through commitRead(). first, last, serr := e.startRead() if serr != nil { - if serr == tcpip.ErrClosedForReceive { + if _, ok := serr.(*tcpip.ErrClosedForReceive); ok { e.stats.ReadErrors.ReadClosed.Increment() } return tcpip.ReadResult{}, serr @@ -1377,7 +1377,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // If something is read, we must report it. Report error when nothing is read. if done == 0 && err != nil { - return tcpip.ReadResult{}, tcpip.ErrBadBuffer + return tcpip.ReadResult{}, &tcpip.ErrBadBuffer{} } return tcpip.ReadResult{ Count: done, @@ -1389,7 +1389,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // inclusive range of segments that can be read. // // Precondition: e.rcvReadMu must be held. -func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { +func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1398,7 +1398,7 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { // on a receive. It can expect to read any data after the handshake // is complete. RFC793, section 3.9, p58. if e.EndpointState() == StateSynSent { - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } // The endpoint can be read if it's connected, or if it's already closed @@ -1414,17 +1414,17 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { if err := e.hardErrorLocked(); err != nil { return nil, nil, err } - return nil, nil, tcpip.ErrClosedForReceive + return nil, nil, &tcpip.ErrClosedForReceive{} } e.stats.ReadErrors.NotConnected.Increment() - return nil, nil, tcpip.ErrNotConnected + return nil, nil, &tcpip.ErrNotConnected{} } if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { - return nil, nil, tcpip.ErrClosedForReceive + return nil, nil, &tcpip.ErrClosedForReceive{} } - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } return e.rcvList.Front(), e.rcvList.Back(), nil @@ -1476,106 +1476,117 @@ func (e *endpoint) commitRead(done int) *segment { // moment. If the endpoint is not writable then it returns an error // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu -func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { +func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: if err := e.hardErrorLocked(); err != nil { return 0, err } - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} case !s.connecting() && !s.connected(): - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} case s.connecting(): // As per RFC793, page 56, a send request arriving when in connecting // state, can be queued to be completed after the state becomes // connected. Return an error code for the caller of endpoint Write to // try again, until the connection handshake is complete. - return 0, tcpip.ErrWouldBlock + return 0, &tcpip.ErrWouldBlock{} } // Check if the connection has already been closed for sends. if e.sndClosed { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } - avail := e.sndBufSize - e.sndBufUsed + sndBufSize := e.getSendBufferSize() + avail := sndBufSize - e.sndBufUsed if avail <= 0 { - return 0, tcpip.ErrWouldBlock + return 0, &tcpip.ErrWouldBlock{} } return avail, nil } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. e.LockUser() - e.sndBufMu.Lock() - - avail, err := e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, err - } - - // We can release locks while copying data. - // - // This is not possible if atomic is set, because we can't allow the - // available buffer space to be consumed by some other caller while we - // are copying data in. - if !opts.Atomic { - e.sndBufMu.Unlock() - e.UnlockUser() - } - - // Fetch data. - v, perr := p.Payload(avail) - if perr != nil || len(v) == 0 { - // Note that perr may be nil if len(v) == 0. - if opts.Atomic { - e.sndBufMu.Unlock() - e.UnlockUser() - } - return 0, perr - } + defer e.UnlockUser() - if !opts.Atomic { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - e.LockUser() + nextSeg, n, err := func() (*segment, int, tcpip.Error) { e.sndBufMu.Lock() + defer e.sndBufMu.Unlock() + avail, err := e.isEndpointWritableLocked() if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() e.stats.WriteErrors.WriteClosed.Increment() - return 0, err + return nil, 0, err } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + v, err := func() ([]byte, tcpip.Error) { + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + defer e.sndBufMu.Lock() + + e.UnlockUser() + defer e.LockUser() + } + + // Fetch data. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return nil, nil + } + v := make([]byte, avail) + if _, err := io.ReadFull(p, v); err != nil { + return nil, &tcpip.ErrBadBuffer{} + } + return v, nil + }() + if len(v) == 0 || err != nil { + return nil, 0, err + } + + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.stats.WriteErrors.WriteClosed.Increment() + return nil, 0, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } - } - // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() + // Add data to the send queue. + s := newOutgoingSegment(e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) - // Do the work inline. - e.handleWrite() - e.UnlockUser() - return int64(len(v)), nil + return e.drainSendQueueLocked(), len(v), nil + }() + if err != nil { + return 0, err + } + e.sendData(nextSeg) + return int64(n), nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1682,8 +1693,16 @@ func (e *endpoint) OnCorkOptionSet(v bool) { } } +func (e *endpoint) getSendBufferSize() int { + sndBufSize, err := e.ops.GetSendBufferSize() + if err != nil { + panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err)) + } + return int(sndBufSize) +} + // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 const inetECNMask = 3 @@ -1711,7 +1730,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.MaxSegOption: userMSS := v if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } e.LockUser() e.userMSS = uint16(userMSS) @@ -1722,7 +1741,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Return not supported if attempting to set this option to // anything other than path MTU discovery disabled. if v != tcpip.PMTUDiscoveryDont { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } case tcpip.ReceiveBufferSizeOption: @@ -1775,31 +1794,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss tcpip.TCPSendBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) - } - - if v > ss.Max { - v = ss.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < ss.Min { - v = ss.Min - } - } else { - v = math.MaxInt32 - } - - e.sndBufMu.Lock() - e.sndBufSize = v - e.sndBufMu.Unlock() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1807,7 +1801,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.TCPSynCountOption: if v < 1 || v > 255 { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } e.LockUser() e.maxSynRetries = uint8(v) @@ -1823,7 +1817,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: e.UnlockUser() - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -1844,7 +1838,7 @@ func (e *endpoint) HasNIC(id int32) bool { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() @@ -1890,7 +1884,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. - return tcpip.ErrNoSuchFile + return &tcpip.ErrNoSuchFile{} case *tcpip.TCPLingerTimeoutOption: e.LockUser() @@ -1933,13 +1927,13 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } // readyReceiveSize returns the number of bytes ready to be received. -func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { +func (e *endpoint) readyReceiveSize() (int, tcpip.Error) { e.LockUser() defer e.UnlockUser() // The endpoint cannot be in listen state. if e.EndpointState() == StateListen { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } e.rcvListMu.Lock() @@ -1949,7 +1943,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.KeepaliveCountOption: e.keepalive.Lock() @@ -1985,12 +1979,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - v := e.sndBufSize - e.sndBufMu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvListMu.Lock() v := e.rcvBufSize @@ -2019,24 +2007,38 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return 1, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} + } +} + +func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { + info := tcpip.TCPInfoOption{} + e.LockUser() + snd := e.snd + if snd != nil { + // We do not calculate RTT before sending the data packets. If + // the connection did not send and receive data, then RTT will + // be zero. + snd.rtt.Lock() + info.RTT = snd.rtt.srtt + info.RTTVar = snd.rtt.rttvar + snd.rtt.Unlock() + + info.RTO = snd.rto + info.CcState = snd.state + info.SndSsthresh = uint32(snd.sndSsthresh) + info.SndCwnd = uint32(snd.sndCwnd) + info.ReorderSeen = snd.rc.reorderSeen } + e.UnlockUser() + return info } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.TCPInfoOption: - *o = tcpip.TCPInfoOption{} - e.LockUser() - snd := e.snd - e.UnlockUser() - if snd != nil { - snd.rtt.Lock() - o.RTT = snd.rtt.srtt - o.RTTVar = snd.rtt.rttvar - snd.rtt.Unlock() - } + *o = e.getTCPInfo() case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() @@ -2082,14 +2084,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } return nil } // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -2098,18 +2100,20 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect connects the endpoint to its peer. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { err := e.connect(addr, true, true) - if err != nil && !err.IgnoreStats() { - // Connect failed. Let's wake up any waiters. - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() + if err != nil { + if !err.IgnoreStats() { + // Connect failed. Let's wake up any waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } } return err } @@ -2120,7 +2124,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -2139,7 +2143,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return nil } // Otherwise return that it's already connected. - return tcpip.ErrAlreadyConnected + return &tcpip.ErrAlreadyConnected{} } nicID := addr.NIC @@ -2152,7 +2156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if nicID != 0 && nicID != e.boundNICID { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } nicID = e.boundNICID @@ -2164,16 +2168,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc case StateConnecting, StateSynSent, StateSynRecv: // A connection request has already been issued but hasn't completed // yet. - return tcpip.ErrAlreadyConnecting + return &tcpip.ErrAlreadyConnecting{} case StateError: if err := e.hardErrorLocked(); err != nil { return err } - return tcpip.ErrConnectionAborted + return &tcpip.ErrConnectionAborted{} default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // Find a route to the desired destination. @@ -2190,7 +2194,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -2229,12 +2233,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { - if err != tcpip.ErrPortInUse || !reuse { + if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } transEPID := e.ID @@ -2278,9 +2282,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) - if err == tcpip.ErrPortInUse { + if _, ok := err.(*tcpip.ErrPortInUse); ok { return false, nil } return false, err @@ -2335,23 +2339,23 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } - return tcpip.ErrConnectStarted + return &tcpip.ErrConnectStarted{} } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection to its // peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() return e.shutdownLocked(flags) } -func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.shutdownFlags |= flags switch { case e.EndpointState().connected(): @@ -2366,7 +2370,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) // Wake up worker to terminate loop. e.notifyProtocolGoroutine(notifyTickleWorker) return nil @@ -2380,7 +2384,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // Already closed. e.sndBufMu.Unlock() if e.EndpointState() == StateTimeWait { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } return nil } @@ -2413,22 +2417,24 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } return nil default: - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } } // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) tcpip.Error { err := e.listen(backlog) - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() + if err != nil { + if !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } } return err } -func (e *endpoint) listen(backlog int) *tcpip.Error { +func (e *endpoint) listen(backlog int) tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -2446,7 +2452,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // Adjust the size of the channel iff we can fix // existing pending connections into the new one. if len(e.acceptedChan) > backlog { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } if cap(e.acceptedChan) == backlog { return nil @@ -2478,11 +2484,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // Endpoint must be bound before it can transition to listen mode. if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } @@ -2518,7 +2524,7 @@ func (e *endpoint) startAcceptedLoop() { // to an endpoint previously set to listen mode. // // addr if not-nil will contain the peer address of the returned endpoint. -func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2527,7 +2533,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. e.rcvListMu.Unlock() // Endpoint must be in listen state before it can accept connections. if rcvClosed || e.EndpointState() != StateListen { - return nil, nil, tcpip.ErrInvalidEndpointState + return nil, nil, &tcpip.ErrInvalidEndpointState{} } // Get the new accepted endpoint. @@ -2538,7 +2544,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. case n = <-e.acceptedChan: e.acceptCond.Signal() default: - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } if peerAddr != nil { *peerAddr = n.getRemoteAddress() @@ -2547,19 +2553,19 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. } // Bind binds the endpoint to a specific local port and optionally address. -func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { +func (e *endpoint) Bind(addr tcpip.FullAddress) (err tcpip.Error) { e.LockUser() defer e.UnlockUser() return e.bindLocked(addr) } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. if e.EndpointState() != StateInitial { - return tcpip.ErrAlreadyBound + return &tcpip.ErrAlreadyBound{} } e.BindAddr = addr.Addr @@ -2587,7 +2593,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { if len(addr.Addr) != 0 { nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nic == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } e.ID.LocalAddress = addr.Addr } @@ -2604,7 +2610,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. - if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2628,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2640,12 +2646,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.LockUser() defer e.UnlockUser() if !e.EndpointState().connected() { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return e.getRemoteAddress(), nil @@ -2677,7 +2683,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } -func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -2726,26 +2732,27 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlAddressUnreachable: - e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) + e.onICMPError(&tcpip.ErrNetworkUnreachable{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } // updateSndBufferUsage is called by the protocol goroutine when room opens up // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { + sendBufferSize := e.getSendBufferSize() e.sndBufMu.Lock() - notify := e.sndBufUsed >= e.sndBufSize>>1 + notify := e.sndBufUsed >= sendBufferSize>>1 e.sndBufUsed -= v - // We only notify when there is half the sndBufSize available after + // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < e.sndBufSize>>1 + notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 e.sndBufMu.Unlock() if notify { @@ -2957,8 +2964,9 @@ func (e *endpoint) completeState() stack.TCPEndpointState { s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() // Copy endpoint send state. + sndBufSize := e.getSendBufferSize() e.sndBufMu.Lock() - s.SndBufSize = e.sndBufSize + s.SndBufSize = sndBufSize s.SndBufUsed = e.sndBufUsed s.SndClosed = e.sndClosed s.SndBufInQueue = e.sndBufInQueue @@ -3023,12 +3031,16 @@ func (e *endpoint) completeState() stack.TCPEndpointState { rc := &e.snd.rc s.Sender.RACKState = stack.TCPRACKState{ - XmitTime: rc.xmitTime, - EndSequence: rc.endSequence, - FACK: rc.fack, - RTT: rc.rtt, - Reord: rc.reorderSeen, - DSACKSeen: rc.dsackSeen, + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + Reord: rc.reorderSeen, + DSACKSeen: rc.dsackSeen, + ReoWnd: rc.reoWnd, + ReoWndIncr: rc.reoWndIncr, + ReoWndPersist: rc.reoWndPersist, + RTTSeq: rc.rttSeq, } return s } @@ -3103,3 +3115,17 @@ func (e *endpoint) Wait() { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// GetTCPSendBufferLimits is used to get send buffer size limits for TCP. +func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { + var ss tcpip.TCPSendBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.SendBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ba67176b5..c21dbc682 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,9 +55,11 @@ func (e *endpoint) beforeSave() { case epState.connected() || epState.handshake(): if !e.route.HasSaveRestoreCapability() { if !e.route.HasDisconncetOkCapability() { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) + panic(&tcpip.ErrSaveRejection{ + Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), + }) } - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) e.mu.Unlock() e.Close() e.mu.Lock() @@ -179,14 +181,16 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + sendBufferSize := e.getSendBufferSize() + if sendBufferSize < ss.Min || sendBufferSize > ss.Max { + panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max)) } } @@ -228,7 +232,8 @@ func (e *endpoint) Resume(s *stack.Stack) { // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } e.mu.Lock() @@ -265,7 +270,8 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { + err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -292,24 +298,6 @@ func (e *endpoint) Resume(s *stack.Stack) { } } -// saveLastError is invoked by stateify. -func (e *endpoint) saveLastError() string { - if e.lastError == nil { - return "" - } - - return e.lastError.String() -} - -// loadLastError is invoked by stateify. -func (e *endpoint) loadLastError(s string) { - if s == "" { - return - } - - e.lastError = tcpip.StringToError(s) -} - // saveRecentTSTime is invoked by stateify. func (e *endpoint) saveRecentTSTime() unixTime { return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} @@ -320,24 +308,6 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { e.recentTSTime = time.Unix(unix.second, unix.nano) } -// saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { - return "" - } - - return e.hardError.String() -} - -// loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { - if s == "" { - return - } - - e.hardError = tcpip.StringToError(s) -} - // saveMeasureTime is invoked by stateify. func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 596178625..2f9fe7ee0 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -143,12 +143,12 @@ func (r *ForwarderRequest) Complete(sendReset bool) { // CreateEndpoint creates a TCP endpoint for the connection request, performing // the 3-way handshake in the process. -func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { r.mu.Lock() defer r.mu.Unlock() if r.segment == nil { - return nil, tcpip.ErrInvalidEndpointState + return nil, &tcpip.ErrInvalidEndpointState{} } f := r.forwarder diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 1720370c9..04012cd40 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -161,13 +161,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } @@ -178,7 +178,7 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given tcp // packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { h := header.TCP(v) return h.SourcePort(), h.DestinationPort(), nil } @@ -216,7 +216,7 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, // replyWithReset replies to the given segment with a reset segment. // // If the passed TTL is 0, then the route's default TTL will be used. -func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err @@ -261,7 +261,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPSACKEnabled: p.mu.Lock() @@ -283,7 +283,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.sendBufferSize = *v @@ -292,7 +292,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.recvBufferSize = *v @@ -310,7 +310,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi } // linux returns ENOENT when an invalid congestion control // is specified. - return tcpip.ErrNoSuchFile + return &tcpip.ErrNoSuchFile{} case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() @@ -340,7 +340,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPTimeWaitReuseOption: if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.timeWaitReuse = *v @@ -381,7 +381,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPSynRetriesOption: if *v < 1 || *v > 255 { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.synRetries = uint8(*v) @@ -389,12 +389,12 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPSACKEnabled: p.mu.RLock() @@ -493,7 +493,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 307bacca5..d85cb405a 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -22,12 +22,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) -// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as -// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. -// It stands for worst case delayed ACK timer (WCDelAckT). When FlightSize is -// 1, PTO is inflated by WCDelAckT time to compensate for a potential long -// delayed ACK timer at the receiver. -const wcDelayedACKTimeout = 200 * time.Millisecond +const ( + // wcDelayedACKTimeout is the recommended maximum delayed ACK timer + // value as defined in the RFC. It stands for worst case delayed ACK + // timer (WCDelAckT). When FlightSize is 1, PTO is inflated by + // WCDelAckT time to compensate for a potential long delayed ACK timer + // at the receiver. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. + wcDelayedACKTimeout = 200 * time.Millisecond + + // tcpRACKRecoveryThreshold is the number of loss recoveries for which + // the reorder window is inflated and after that the reorder window is + // reset to its initial value of minRTT/4. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2. + tcpRACKRecoveryThreshold = 16 +) // RACK is a loss detection algorithm used in TCP to detect packet loss and // reordering using transmission timestamp of the packets instead of packet or @@ -44,6 +53,11 @@ type rackControl struct { // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value + // exitedRecovery indicates if the connection is exiting loss recovery. + // This flag is set if the sender is leaving the recovery after + // receiving an ACK and is reset during updating of reorder window. + exitedRecovery bool + // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value @@ -51,15 +65,30 @@ type rackControl struct { // minRTT is the estimated minimum RTT of the connection. minRTT time.Duration + // reorderSeen indicates if reordering has been detected on this + // connection. + reorderSeen bool + + // reoWnd is the reordering window time used for recording packet + // transmission times. It is used to defer the moment at which RACK + // marks a packet lost. + reoWnd time.Duration + + // reoWndIncr is the multiplier applied to adjust reorder window. + reoWndIncr uint8 + + // reoWndPersist is the number of loss recoveries before resetting + // reorder window. + reoWndPersist int8 + // rtt is the RTT of the most recently delivered packet on the // connection (either cumulatively acknowledged or selectively // acknowledged) that was not marked invalid as a possible spurious // retransmission. rtt time.Duration - // reorderSeen indicates if reordering has been detected on this - // connection. - reorderSeen bool + // rttSeq is the SND.NXT when rtt is updated. + rttSeq seqnum.Value // xmitTime is the latest transmission timestamp of rackControl.seg. xmitTime time.Time `state:".(unixTime)"` @@ -75,29 +104,36 @@ type rackControl struct { // tlpHighRxt the value of sender.sndNxt at the time of sending // a TLP retransmission. tlpHighRxt seqnum.Value + + // snd is a reference to the sender. + snd *sender } // init initializes RACK specific fields. -func (rc *rackControl) init() { +func (rc *rackControl) init(snd *sender, iss seqnum.Value) { + rc.fack = iss + rc.reoWndIncr = 1 + rc.snd = snd rc.probeTimer.init(&rc.probeWaker) } // update will update the RACK related fields when an ACK has been received. -// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 -func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) { +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 +func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := time.Now().Sub(seg.xmitTime) + tsOffset := rc.snd.ep.tsOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: - // 1. When Timestamping option is available, if the TSVal is less than the - // transmit time of the most recent retransmitted packet. + // 1. When Timestamping option is available, if the TSVal is less than + // the transmit time of the most recent retransmitted packet. // 2. When RTT calculated for the packet is less than the smoothed RTT // for the connection. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // step 2 if seg.xmitCount > 1 { if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { - if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) { + if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) { return } } @@ -149,9 +185,8 @@ func (rc *rackControl) detectReorder(seg *segment) { } } -// setDSACKSeen updates rack control if duplicate SACK is seen by the connection. -func (rc *rackControl) setDSACKSeen() { - rc.dsackSeen = true +func (rc *rackControl) setDSACKSeen(dsackSeen bool) { + rc.dsackSeen = dsackSeen } // shouldSchedulePTO dictates whether we should schedule a PTO or not. @@ -162,7 +197,7 @@ func (s *sender) shouldSchedulePTO() bool { // The connection supports SACK. s.ep.sackPermitted && // The connection is not in loss recovery. - (s.state != RTORecovery && s.state != SACKRecovery) && + (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) && // The connection has no SACKed sequences in the SACK scoreboard. s.ep.scoreboard.Sacked() == 0 } @@ -193,7 +228,7 @@ func (s *sender) schedulePTO() { // probeTimerExpired is the same as TLP_send_probe() as defined in // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2. -func (s *sender) probeTimerExpired() *tcpip.Error { +func (s *sender) probeTimerExpired() tcpip.Error { if !s.rc.probeTimer.checkExpiration() { return nil } @@ -272,3 +307,82 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { } } } + +// updateRACKReorderWindow updates the reorder window. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +// * Step 4: Update RACK reordering window +// To handle the prevalent small degree of reordering, RACK.reo_wnd serves as +// an allowance for settling time before marking a packet lost. RACK starts +// initially with a conservative window of min_RTT/4. If no reordering has +// been observed RACK uses reo_wnd of zero during loss recovery, in order to +// retransmit quickly, or when the number of DUPACKs exceeds the classic +// DUPACKthreshold. +func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { + dsackSeen := rc.dsackSeen + snd := rc.snd + + // React to DSACK once per round trip. + // If SND.UNA < RACK.rtt_seq: + // RACK.dsack = false + if snd.sndUna.LessThan(rc.rttSeq) { + dsackSeen = false + } + + // If RACK.dsack: + // RACK.reo_wnd_incr += 1 + // RACK.dsack = false + // RACK.rtt_seq = SND.NXT + // RACK.reo_wnd_persist = 16 + if dsackSeen { + rc.reoWndIncr++ + dsackSeen = false + rc.rttSeq = snd.sndNxt + rc.reoWndPersist = tcpRACKRecoveryThreshold + } else if rc.exitedRecovery { + // Else if exiting loss recovery: + // RACK.reo_wnd_persist -= 1 + // If RACK.reo_wnd_persist <= 0: + // RACK.reo_wnd_incr = 1 + rc.reoWndPersist-- + if rc.reoWndPersist <= 0 { + rc.reoWndIncr = 1 + } + rc.exitedRecovery = false + } + + // Reorder window is zero during loss recovery, or when the number of + // DUPACKs exceeds the classic DUPACKthreshold. + // If RACK.reord is FALSE: + // If in loss recovery: (If in fast or timeout recovery) + // RACK.reo_wnd = 0 + // Return + // Else if RACK.pkts_sacked >= RACK.dupthresh: + // RACK.reo_wnd = 0 + // return + if !rc.reorderSeen { + if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery { + rc.reoWnd = 0 + return + } + + if snd.sackedOut >= nDupAckThreshold { + rc.reoWnd = 0 + return + } + } + + // Calculate reorder window. + // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr + // RACK.reo_wnd = min(RACK.reo_wnd, SRTT) + snd.rtt.Lock() + srtt := snd.rtt.srtt + snd.rtt.Unlock() + rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr)) + if srtt < rc.reoWnd { + rc.reoWnd = srtt + } +} + +func (rc *rackControl) exitRecovery() { + rc.exitedRecovery = true +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 405a6dce7..7a7c402c4 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -347,7 +347,7 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { r.ep.rcvListMu.Lock() rcvClosed := r.ep.rcvClosed || r.closed r.ep.rcvListMu.Unlock() @@ -395,7 +395,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - return true, tcpip.ErrConnectionAborted + return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { break @@ -424,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - return true, tcpip.ErrConnectionAborted + return true, &tcpip.ErrConnectionAborted{} } } @@ -443,7 +443,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { state := r.ep.EndpointState() closed := r.ep.closed diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 079d90848..dfc8fd248 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -48,28 +48,6 @@ const ( MaxRetries = 15 ) -// ccState indicates the current congestion control state for this sender. -type ccState int - -const ( - // Open indicates that the sender is receiving acks in order and - // no loss or dupACK's etc have been detected. - Open ccState = iota - // RTORecovery indicates that an RTO has occurred and the sender - // has entered an RTO based recovery phase. - RTORecovery - // FastRecovery indicates that the sender has entered FastRecovery - // based on receiving nDupAck's. This state is entered only when - // SACK is not in use. - FastRecovery - // SACKRecovery indicates that the sender has entered SACK based - // recovery. - SACKRecovery - // Disorder indicates the sender either received some SACK blocks - // or dupACK's. - Disorder -) - // congestionControl is an interface that must be implemented by any supported // congestion control algorithm. type congestionControl interface { @@ -204,7 +182,7 @@ type sender struct { maxSentAck seqnum.Value // state is the current state of congestion control for this endpoint. - state ccState + state tcpip.CongestionControlState // cc is the congestion control algorithm in use for this sender. cc congestionControl @@ -280,14 +258,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint highRxt: iss, rescueRxt: iss, }, - rc: rackControl{ - fack: iss, - }, gso: ep.gso != nil, } - s.rc.init() - if s.gso { s.ep.gso.MSS = uint16(maxPayloadSize) } @@ -295,6 +268,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.cc = s.initCongestionControl(ep.cc) s.lr = s.initLossRecovery() + s.rc.init(s, iss) // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. @@ -593,7 +567,7 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveRecovery() } - s.state = RTORecovery + s.state = tcpip.RTORecovery s.cc.HandleRTOExpired() // Mark the next segment to be sent as the first unacknowledged one and @@ -1018,7 +992,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { + if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { if s.sndCwnd > InitialCwnd { s.sndCwnd = InitialCwnd } @@ -1062,14 +1036,14 @@ func (s *sender) enterRecovery() { s.fr.highRxt = s.sndUna s.fr.rescueRxt = s.sndUna if s.ep.sackPermitted { - s.state = SACKRecovery + s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() // Set TLPRxtOut to false according to // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1. s.rc.tlpRxtOut = false return } - s.state = FastRecovery + s.state = tcpip.FastRecovery s.ep.stack.Stats().TCP.FastRecovery.Increment() } @@ -1080,7 +1054,6 @@ func (s *sender) leaveRecovery() { // Deflate cwnd. It had been artificially inflated when new dups arrived. s.sndCwnd = s.sndSsthresh - s.cc.PostRecovery() } @@ -1166,7 +1139,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { s.fr.highRxt = s.sndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() - s.state = Disorder + s.state = tcpip.Disorder return false } @@ -1217,11 +1190,13 @@ func (s *sender) isDupAck(seg *segment) bool { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { + s.rc.setDSACKSeen(false) + // Look for DSACK block. idx := 0 n := len(rcvdSeg.parsedOptions.SACKBlocks) if checkDSACK(rcvdSeg) { - s.rc.setDSACKSeen() + s.rc.setDSACKSeen(true) idx = 1 n-- } @@ -1242,7 +1217,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { for _, sb := range sackBlocks { for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { - s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) seg.acked = true s.sackedOut += s.pCount(seg, s.maxPayloadSize) @@ -1412,6 +1387,17 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { acked := s.sndUna.Size(ack) s.sndUna = ack + // The remote ACK-ing at least 1 byte is an indication that we have a + // full-duplex connection to the remote as the only way we will receive an + // ACK is if the remote received data that we previously sent. + // + // As of writing, linux seems to only confirm a route as reachable when + // forward progress is made which is indicated by an ACK that removes data + // from the retransmit queue. + if acked > 0 { + s.ep.route.ConfirmReachable() + } + ackLeft := acked originalOutstanding := s.outstanding for ackLeft > 0 { @@ -1435,7 +1421,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Update the RACK fields if SACK is enabled. if s.ep.sackPermitted && !seg.acked { - s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1464,7 +1450,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if !s.fr.active { s.cc.Update(originalOutstanding - s.outstanding) if s.fr.last.LessThan(s.sndUna) { - s.state = Open + s.state = tcpip.Open + // Update RACK when we are exiting fast or RTO + // recovery as described in the RFC + // draft-ietf-tcpm-rack-08 Section-7.2 Step 4. + s.rc.exitRecovery() } } @@ -1488,6 +1478,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } + // Update RACK reorder window. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // * Upon receiving an ACK: + // * Step 4: Update RACK reordering window + s.rc.updateRACKReorderWindow(rcvdSeg) + // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. if s.fr.active { @@ -1508,7 +1504,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // sendSegment sends the specified segment. -func (s *sender) sendSegment(seg *segment) *tcpip.Error { +func (s *sender) sendSegment(seg *segment) tcpip.Error { if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() @@ -1539,7 +1535,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. -func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index f7aaee23f..ced3a9c58 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -21,13 +21,13 @@ package tcp_test import ( + "bytes" "fmt" "math" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" @@ -42,14 +42,16 @@ func TestFastRecovery(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -207,14 +209,16 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -249,14 +253,16 @@ func TestCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -353,15 +359,16 @@ func TestCubicCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -462,19 +469,20 @@ func TestRetransmit(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in two shots. Packets will only be written at the // MTU size though. - half := data[:len(data)/2] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data[:len(data)/2]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } - half = data[len(data)/2:] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + r.Reset(data[len(data)/2:]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 342eb5eb8..a6a26b705 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -15,11 +15,12 @@ package tcp_test import ( + "bytes" + "fmt" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -61,14 +62,16 @@ func TestRACKUpdate(t *testing.T) { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(maxPayload) + data := make([]byte, maxPayload) for i := range data { data[i] = byte(i) } // Write the data. xmitTime = time.Now() - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -114,13 +117,15 @@ func TestRACKDetectReorder(t *testing.T) { }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(ackNumToVerify * maxPayload) + data := make([]byte, ackNumToVerify*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -141,17 +146,19 @@ func TestRACKDetectReorder(t *testing.T) { <-probeDone } -func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View { +func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(numPackets * maxPayload) + data := make([]byte, numPackets*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -528,3 +535,64 @@ func TestRACKWithInvalidDSACKBlock(t *testing.T) { // ACK before the test completes. <-probeDone } + +func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) { + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK detects DSACK. + n++ + if n < numACK { + return + } + + if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { + probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) + return + } + + if state.Sender.RACKState.ReoWndIncr != 1 { + probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr) + return + } + + if state.Sender.RACKState.ReoWndPersist > 0 { + probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist) + return + } + probeDone <- nil + }) +} + +func TestRACKCheckReorderWindow(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan error) + const ackNumToVerify = 3 + addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone) + + const numPackets = 7 + sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed packets [2-6] which indicates there is reordering + // in the connection. + bytesRead += 6 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + if err := <-probeDone; err != nil { + t.Fatalf("unexpected values for RACK variables: %v", err) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 6635bb815..5024bc925 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -15,6 +15,7 @@ package tcp_test import ( + "bytes" "fmt" "log" "reflect" @@ -22,7 +23,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -395,14 +395,16 @@ func TestSACKRecovery(t *testing.T) { createConnectedWithSACKAndTS(c) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 93683b921..da2730e27 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "math" + "strings" "testing" "time" @@ -26,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -48,7 +48,7 @@ type endpointTester struct { } // CheckReadError issues a read to the endpoint and checking for an error. -func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { +func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) { t.Helper() res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) if got != want { @@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for receive to be notified. select { case <-notifyRead: @@ -128,8 +128,11 @@ func TestGiveUpConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventHUp) defer wq.EventUnregister(&waitEntry) - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } // Close the connection, wait for completion. @@ -140,8 +143,11 @@ func TestGiveUpConnect(t *testing.T) { // Call Connect again to retreive the handshake failure status // and stats updates. - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrAborted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{}) + } } if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { @@ -194,8 +200,11 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { c.EP = ep want := stats.TCP.FailedConnectionAttempts.Value() + 1 - if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute) + { + err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) + } } if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { @@ -211,7 +220,7 @@ func TestCloseWithoutConnect(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -384,7 +393,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -925,8 +934,11 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} - if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("Connect(%+v): %s", connectAddr, err) + { + err := c.EP.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("Connect(%+v): %s", connectAddr, err) + } } // Receive SYN packet with our user supplied MSS. @@ -1347,10 +1359,9 @@ func TestTOSV4(t *testing.T) { testV4Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1396,10 +1407,9 @@ func TestTrafficClassV6(t *testing.T) { testV6Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1444,7 +1454,8 @@ func TestConnectBindToDevice(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("unexpected return value from Connect: %s", err) } @@ -1504,8 +1515,9 @@ func TestSynSent(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { - t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) + err := c.EP.Connect(addr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) } // Receive SYN packet. @@ -1550,9 +1562,9 @@ func TestSynSent(t *testing.T) { ept := endpointTester{c.EP} if test.reset { - ept.CheckReadError(t, tcpip.ErrConnectionRefused) + ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) } else { - ept.CheckReadError(t, tcpip.ErrAborted) + ept.CheckReadError(t, &tcpip.ErrAborted{}) } if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { @@ -1578,7 +1590,7 @@ func TestOutOfOrderReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send second half of data first, with seqnum 3 ahead of expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1603,7 +1615,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send the first 3 bytes now. c.SendPacket(data[:3], &context.Headers{ @@ -1642,7 +1654,7 @@ func TestOutOfOrderFlood(t *testing.T) { c.CreateConnected(789, 30000, rcvBufSz) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send 100 packets before the actual one that is expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1718,7 +1730,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1786,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1868,13 +1880,13 @@ func TestShutdownRead(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { t.Fatalf("Shutdown failed: %s", err) } - ept.CheckReadError(t, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) @@ -1893,7 +1905,7 @@ func TestFullWindowReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies // the provided buffer value by tcp.SegOverheadFactor to calculate the actual @@ -2054,7 +2066,7 @@ func TestNoWindowShrinking(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send a 1 byte payload so that we can record the current receive window. // Send a payload of half the size of rcvBufSize. @@ -2176,10 +2188,9 @@ func TestSimpleSend(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2217,10 +2228,9 @@ func TestZeroWindowSend(t *testing.T) { c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2285,10 +2295,9 @@ func TestScaledWindowConnect(t *testing.T) { }) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2317,10 +2326,9 @@ func TestNonScaledWindowConnect(t *testing.T) { c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2376,7 +2384,7 @@ func TestScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -2391,10 +2399,9 @@ func TestScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2450,7 +2457,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -2465,10 +2472,9 @@ func TestNonScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2632,9 +2638,10 @@ func TestSegmentMerging(t *testing.T) { // Send tcp.InitialCwnd number of segments to fill up // InitialWindow but don't ACK. That should prevent // anymore packets from going out. + var r bytes.Reader for i := 0; i < tcp.InitialCwnd; i++ { - view := buffer.NewViewFromBytes([]byte{0}) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset([]byte{0}) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2644,8 +2651,8 @@ func TestSegmentMerging(t *testing.T) { var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2714,8 +2721,9 @@ func TestDelay(t *testing.T) { var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2761,8 +2769,9 @@ func TestUndelay(t *testing.T) { allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2845,8 +2854,9 @@ func TestMSSNotDelayed(t *testing.T) { allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2894,10 +2904,9 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { data[i] = byte(i) } - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2963,7 +2972,7 @@ func TestSetTTL(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -2973,8 +2982,11 @@ func TestSetTTL(t *testing.T) { t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("unexpected return value from Connect: %s", err) + } } // Receive SYN packet. @@ -3034,7 +3046,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -3090,7 +3102,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -3115,9 +3127,9 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) { defer c.Cleanup() s := c.Stack() - ch := make(chan *tcpip.Error, 1) + ch := make(chan tcpip.Error, 1) f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err *tcpip.Error + var err tcpip.Error c.EP, err = r.CreateEndpoint(&c.WQ) ch <- err }) @@ -3146,7 +3158,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -3165,8 +3177,11 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventOut) defer c.WQ.EventUnregister(&we) - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } // Receive SYN packet. @@ -3276,22 +3291,23 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { - switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err { - case tcpip.ErrWouldBlock: + switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { + case *tcpip.ErrWouldBlock: select { case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset) + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrConnectionReset); !ok { + t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } break loop case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for reset to arrive") } - case tcpip.ErrConnectionReset: + case *tcpip.ErrConnectionReset: break loop default: - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } } @@ -3328,9 +3344,11 @@ func TestSendOnResetConnection(t *testing.T) { time.Sleep(1 * time.Second) // Try to write. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) + var r bytes.Reader + r.Reset(make([]byte, 10)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if _, ok := err.(*tcpip.ErrConnectionReset); !ok { + t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } } @@ -3352,7 +3370,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventHUp) defer c.WQ.EventUnregister(&waitEntry) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3409,7 +3429,9 @@ func TestMaxRTO(t *testing.T) { c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3458,7 +3480,9 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) } - if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(make([]byte, tc.size)) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } pkt := c.GetPacket() @@ -3595,8 +3619,10 @@ func TestFinWithNoPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3667,9 +3693,11 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // Write enough segments to fill the congestion window before ACK'ing // any of them. - view := buffer.NewView(10) + view := make([]byte, 10) + var r bytes.Reader for i := tcp.InitialCwnd; i > 0; i-- { - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } } @@ -3754,8 +3782,10 @@ func TestFinWithPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3781,7 +3811,8 @@ func TestFinWithPendingData(t *testing.T) { }) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3841,8 +3872,10 @@ func TestFinWithPartialAck(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3879,7 +3912,8 @@ func TestFinWithPartialAck(t *testing.T) { ) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3985,8 +4019,10 @@ func scaledSendWindow(t *testing.T, scale uint8) { }) // Send some data. Check that it's capped by the window size. - view := buffer.NewView(65535) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 65535) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4170,7 +4206,7 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -4249,10 +4285,13 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. - ept.CheckReadError(t, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) var buf bytes.Buffer - if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) + { + _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) + if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { + t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) + } } } @@ -4263,7 +4302,7 @@ func TestReusePort(t *testing.T) { defer c.Cleanup() // First case, just an endpoint that was bound. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { t.Fatalf("NewEndpoint failed; %s", err) @@ -4293,8 +4332,11 @@ func TestReusePort(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } c.EP.Close() @@ -4351,9 +4393,9 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + s, err := ep.SocketOptions().GetSendBufferSize() if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) + t.Fatalf("GetSendBufferSize failed: %s", err) } if int(s) != v { @@ -4459,9 +4501,7 @@ func TestMinMaxBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(149, true) checkSendBufferSize(t, ep, 300) @@ -4473,9 +4513,7 @@ func TestMinMaxBufferSizes(t *testing.T) { // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) - } + ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) @@ -4505,11 +4543,11 @@ func TestBindToDeviceOption(t *testing.T) { testActions := []struct { name string setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error + setBindToDeviceError tcpip.Error getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, {"BindToExistent", nicIDPtr(321), nil, 321}, {"UnbindToDevice", nicIDPtr(0), nil, 0}, } @@ -4529,7 +4567,7 @@ func TestBindToDeviceOption(t *testing.T) { } } -func makeStack() (*stack.Stack, *tcpip.Error) { +func makeStack() (*stack.Stack, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -4599,8 +4637,11 @@ func TestSelfConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventOut) defer wq.EventUnregister(&waitEntry) - if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } <-notifyCh @@ -4610,9 +4651,9 @@ func TestSelfConnect(t *testing.T) { // Write something. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4752,9 +4793,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("Bind(%d) failed: %s", i, err) } } - want := tcpip.ErrConnectStarted + var want tcpip.Error = &tcpip.ErrConnectStarted{} if collides { - want = tcpip.ErrNoPortAvailable + want = &tcpip.ErrNoPortAvailable{} } if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) @@ -4785,12 +4826,13 @@ func TestPathMTUDiscovery(t *testing.T) { // Send 3200 bytes of data. const writeSize = 3200 - data := buffer.NewView(writeSize) + data := make([]byte, writeSize) for i := range data { data[i] = byte(i) } - - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4878,11 +4920,11 @@ func TestTCPEndpointProbe(t *testing.T) { func TestStackSetCongestionControl(t *testing.T) { testCases := []struct { cc tcpip.CongestionControlOption - err *tcpip.Error + err tcpip.Error }{ {"reno", nil}, {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, } for _, tc := range testCases { @@ -4964,11 +5006,11 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { func TestEndpointSetCongestionControl(t *testing.T) { testCases := []struct { cc tcpip.CongestionControlOption - err *tcpip.Error + err tcpip.Error }{ {"reno", nil}, {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, } for _, connected := range []bool{false, true} { @@ -4978,7 +5020,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5074,12 +5116,14 @@ func TestKeepalive(t *testing.T) { // Check that the connection is still alive. ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5163,7 +5207,7 @@ func TestKeepalive(t *testing.T) { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) } - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) @@ -5270,7 +5314,7 @@ func TestListenBacklogFull(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5313,7 +5357,7 @@ func TestListenBacklogFull(t *testing.T) { for i := 0; i < listenBacklog; i++ { _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5330,7 +5374,7 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5342,7 +5386,7 @@ func TestListenBacklogFull(t *testing.T) { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5358,7 +5402,9 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { @@ -5583,7 +5629,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5658,7 +5704,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { defer c.WQ.EventUnregister(&we) newEP, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5674,7 +5720,9 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() tcp = header.TCP(header.IPv4(pkt).Payload()) if string(tcp.Payload()) != data { @@ -5692,7 +5740,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5733,7 +5781,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { defer c.WQ.EventUnregister(&we) _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5749,7 +5797,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5763,7 +5811,7 @@ func TestSYNRetransmit(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5807,7 +5855,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5882,12 +5930,13 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { }) newEP, _, err := c.EP.Accept(nil) - - if err != nil && err != tcpip.ErrWouldBlock { + switch err.(type) { + case nil, *tcpip.ErrWouldBlock: + default: t.Fatalf("Accept failed: %s", err) } - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Try to accept the connections in the backlog. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -5908,7 +5957,9 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil { + var r strings.Reader + r.Reset(data) + if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5953,7 +6004,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { // Verify that there is only one acceptable connection at this point. _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6023,7 +6074,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6055,7 +6106,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } ept := endpointTester{ep} - ept.CheckReadError(t, tcpip.ErrNotConnected) + ept.CheckReadError(t, &tcpip.ErrNotConnected{}) if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) } @@ -6075,7 +6126,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { defer wq.EventUnregister(&we) aep, _, err := ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6091,8 +6142,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected) + { + err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { + t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) + } } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { @@ -6211,7 +6265,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // window increases to the full available buffer size. for { _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { break } } @@ -6335,7 +6389,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalCopied := 0 for { res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { break } totalCopied += res.Count @@ -6527,7 +6581,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6646,7 +6700,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6753,7 +6807,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6843,7 +6897,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Try to accept the connection. c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6917,7 +6971,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -7067,7 +7121,7 @@ func TestTCPCloseWithData(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -7103,10 +7157,10 @@ func TestTCPCloseWithData(t *testing.T) { // Now write a few bytes and then close the endpoint. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7204,8 +7258,10 @@ func TestTCPUserTimeout(t *testing.T) { } // Send some data and wait before ACKing it. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7256,7 +7312,7 @@ func TestTCPUserTimeout(t *testing.T) { ) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) @@ -7300,7 +7356,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { // Check that the connection is still alive. ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Now receive 1 keepalives, but don't ACK it. b := c.GetPacket() @@ -7339,7 +7395,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { ), ) - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) } @@ -7494,8 +7550,9 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) + _, _, err := c.EP.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) } // Send data. This should result in an acceptable endpoint. @@ -7552,8 +7609,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) + _, _, err := c.EP.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7675,13 +7733,13 @@ func TestSetStackTimeWaitReuse(t *testing.T) { s := c.Stack() testCases := []struct { v int - err *tcpip.Error + err tcpip.Error }{ {int(tcpip.TCPTimeWaitReuseDisabled), nil}, {int(tcpip.TCPTimeWaitReuseGlobal), nil}, {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue}, - {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, + {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, } for _, tc := range testCases { diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index b65091c3c..5a9745ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -22,7 +22,6 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -152,10 +151,10 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS // Now send some data and validate that timestamp is echoed correctly in the ACK. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } @@ -215,10 +214,10 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd // Now send some data with the accepted connection endpoint and validate // that no timestamp option is sent in the TCP segment. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ee55f030c..b1cb9a324 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -586,7 +586,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int // is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 // only endpoint instead of a default dual stack socket. func (c *Context) CreateV6Endpoint(v6only bool) { - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) @@ -689,7 +689,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { + err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { c.t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -749,7 +750,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // Create creates a TCP endpoint. func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) @@ -887,7 +888,7 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { // It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK // does not carry an option that was not requested. func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) @@ -903,7 +904,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} err = c.EP.Connect(testFullAddr) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) } // Receive SYN packet. @@ -1054,7 +1055,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9f9b3d510..31a5ddce9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -97,9 +97,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int + mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. state EndpointState @@ -111,8 +109,8 @@ type endpoint struct { multicastNICID tcpip.NICID portFlags ports.Flags - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // Values used to reserve a port or register a transport endpoint. // (which ever happens first). @@ -176,18 +174,18 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) e.ops.SetMulticastLoop(true) + e.ops.SetSendBufferSize(32*1024, false /* notify */) // Override with stack defaults. - var ss stack.SendBufferSizeOption + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { - e.sndBufSizeMax = ss.Default + e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } var rs stack.ReceiveBufferSizeOption @@ -217,7 +215,7 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } -func (e *endpoint) LastError() *tcpip.Error { +func (e *endpoint) LastError() tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() @@ -227,7 +225,7 @@ func (e *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (e *endpoint) UpdateLastError(err *tcpip.Error) { +func (e *endpoint) UpdateLastError(err tcpip.Error) { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -246,7 +244,7 @@ func (e *endpoint) Close() { switch e.EndpointState() { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -284,7 +282,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err } @@ -292,10 +290,10 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult e.rcvMu.Lock() if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -342,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil @@ -353,7 +351,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.EndpointState() { case StateInitial: case StateConnected: @@ -361,11 +359,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi case StateBound: if to == nil { - return false, tcpip.ErrDestinationRequired + return false, &tcpip.ErrDestinationRequired{} } return false, nil default: - return false, tcpip.ErrInvalidEndpointState + return false, &tcpip.ErrInvalidEndpointState{} } e.mu.RUnlock() @@ -391,7 +389,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { localAddr := e.ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -417,18 +415,18 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -438,14 +436,14 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { if err := e.LastError(); err != nil { return 0, err } // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } to := opts.To @@ -461,7 +459,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } // Prepare for write. @@ -482,9 +480,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } nicID = e.BindNICID @@ -492,7 +493,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc if to.Port == 0 { // Port 0 is an invalid port to send to. - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } dst, netProto, err := e.checkV4MappedLocked(*to) @@ -511,19 +512,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, tcpip.ErrBroadcastDisabled + return 0, &tcpip.ErrBroadcastDisabled{} } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. so := e.SocketOptions() if so.GetRecvError() { so.QueueLocalErr( - tcpip.ErrMessageTooLong, + &tcpip.ErrMessageTooLong{}, route.NetProto, header.UDPMaximumPacketSize, tcpip.FullAddress{ @@ -534,7 +535,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc v, ) } - return 0, tcpip.ErrMessageTooLong + return 0, &tcpip.ErrMessageTooLong{} } ttl := e.ttl @@ -584,13 +585,13 @@ func (e *endpoint) OnReusePortSet(v bool) { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.MTUDiscoverOption: // Return not supported if the value is not disabling path // MTU discovery. if v != tcpip.PMTUDiscoveryDont { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } case tcpip.MulticastTTLOption: @@ -632,25 +633,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvBufSizeMax = v e.mu.Unlock() return nil - case tcpip.SendBufferSizeOption: - // Make sure the send buffer size is within the min and max - // allowed. - var ss stack.SendBufferSizeOption - if err := e.stack.Option(&ss); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) - } - - if v < ss.Min { - v = ss.Min - } - if v > ss.Max { - v = ss.Max - } - - e.mu.Lock() - e.sndBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -661,7 +643,7 @@ func (e *endpoint) HasNIC(id int32) bool { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() @@ -683,17 +665,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { if nic != 0 { if !e.stack.CheckNIC(nic) { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } else { nic = e.stack.CheckLocalAddress(0, netProto, addr) if nic == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } if e.BindNICID != 0 && e.BindNICID != nic { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } e.multicastNICID = nic @@ -701,7 +683,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.AddMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } nicID := v.NIC @@ -717,7 +699,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} @@ -726,7 +708,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { defer e.mu.Unlock() if _, ok := e.multicastMemberships[memToInsert]; ok { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { @@ -737,7 +719,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } nicID := v.NIC @@ -752,7 +734,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} @@ -761,7 +743,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { defer e.mu.Unlock() if _, ok := e.multicastMemberships[memToRemove]; !ok { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { @@ -777,7 +759,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.IPv4TOSOption: e.mu.RLock() @@ -811,12 +793,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.SendBufferSizeOption: - e.mu.Lock() - v := e.sndBufSizeMax - e.mu.Unlock() - return v, nil - case tcpip.ReceiveBufferSizeOption: e.rcvMu.Lock() v := e.rcvBufSizeMax @@ -830,12 +806,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() @@ -846,14 +822,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { e.mu.Unlock() default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), Data: data, @@ -903,7 +879,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -912,7 +888,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (e *endpoint) Disconnect() *tcpip.Error { +func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -930,12 +906,12 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { - var err *tcpip.Error + var err tcpip.Error id = stack.TransportEndpointID{ LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, } - id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { return err } @@ -950,7 +926,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.setEndpointState(StateInitial) } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() @@ -961,10 +937,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { } // Connect connects the endpoint to its peer. Specifying a NIC is optional. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { if addr.Port == 0 { // We don't support connecting to port zero. - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } e.mu.Lock() @@ -981,12 +957,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } if nicID != 0 && nicID != e.BindNICID { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nicID = e.BindNICID default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -1023,7 +999,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { oldPortFlags := e.boundPortFlags - id, btd, err := e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(netProtos, id) if err != nil { r.Release() return err @@ -1031,11 +1007,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) } e.ID = id e.boundBindToDevice = btd + if e.route != nil { + // If the endpoint was already connected then make sure we release the + // previous route. + e.route.Release() + } e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID @@ -1051,20 +1032,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection // to its peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. if state := e.EndpointState(); state != StateBound && state != StateConnected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } e.shutdownFlags |= flags @@ -1084,16 +1065,16 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen is not supported by UDP, it just fails. -func (*endpoint) Listen(int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) @@ -1104,7 +1085,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} @@ -1112,11 +1093,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ return id, bindToDevice, err } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. if e.EndpointState() != StateInitial { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -1140,7 +1121,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // A local unicast address was specified, verify that it's valid. nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicID == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } @@ -1148,7 +1129,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, btd, err := e.registerWithStack(nicID, netProtos, id) + id, btd, err := e.registerWithStack(netProtos, id) if err != nil { return err } @@ -1170,7 +1151,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -1186,7 +1167,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -1203,12 +1184,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() if e.EndpointState() != StateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return tcpip.FullAddress{ @@ -1341,7 +1322,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -1397,7 +1378,7 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt default: panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) } - e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt) + e.onICMPError(&tcpip.ErrConnectionRefused{}, errType, errCode, extra, pkt) return } } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 13b72dc88..21a6aa460 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,24 +37,6 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// saveLastError is invoked by stateify. -func (e *endpoint) saveLastError() string { - if e.lastError == nil { - return "" - } - - return e.lastError.String() -} - -// loadLastError is invoked by stateify. -func (e *endpoint) loadLastError(s string) { - if s == "" { - return - } - - e.lastError = tcpip.StringToError(s) -} - // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). @@ -91,6 +73,7 @@ func (e *endpoint) Resume(s *stack.Stack) { defer e.mu.Unlock() e.stack = s + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { @@ -113,7 +96,7 @@ func (e *endpoint) Resume(s *stack.Stack) { netProto = header.IPv6ProtocolNumber } - var err *tcpip.Error + var err tcpip.Error if state == StateConnected { e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { @@ -122,7 +105,7 @@ func (e *endpoint) Resume(s *stack.Stack) { } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } @@ -131,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // pass it to the reservation machinery. id := e.ID e.ID.LocalPort = 0 - e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 49e673d58..705ad1f64 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -69,7 +69,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { } // CreateEndpoint creates a connected UDP endpoint for the session request. -func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { netHdr := r.pkt.Network() route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) if err != nil { @@ -77,7 +77,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { + if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 91420edd3..427fdd0c9 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -54,13 +54,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } @@ -71,7 +71,7 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given udp // packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { h := header.UDP(v) return h.SourcePort(), h.DestinationPort(), nil } @@ -94,13 +94,13 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4e2123fe9..64c5298d3 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -353,7 +353,7 @@ func (c *testContext) cleanup() { func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { c.t.Helper() - var err *tcpip.Error + var err tcpip.Error c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { c.t.Fatal("NewEndpoint failed: ", err) @@ -555,11 +555,11 @@ func TestBindToDeviceOption(t *testing.T) { testActions := []struct { name string setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error + setBindToDeviceError tcpip.Error getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, {"BindToExistent", nicIDPtr(321), nil, 321}, {"UnbindToDevice", nicIDPtr(0), nil, 0}, } @@ -599,7 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe var buf bytes.Buffer res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for data to become available. select { case <-ch: @@ -703,8 +703,11 @@ func TestBindReservedPort(t *testing.T) { t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() - if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) + { + err := ep.Bind(addr) + if _, ok := err.(*tcpip.ErrPortInUse); !ok { + t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) + } } } @@ -716,8 +719,11 @@ func TestBindReservedPort(t *testing.T) { defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint // above, since the endpoint is dual-stack. - if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) + { + err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) + if _, ok := err.(*tcpip.ErrPortInUse); !ok { + t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) + } } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { @@ -806,11 +812,11 @@ func TestV4ReadSelfSource(t *testing.T) { for _, tt := range []struct { name string handleLocal bool - wantErr *tcpip.Error + wantErr tcpip.Error wantInvalidSource uint64 }{ {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, + {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ @@ -959,15 +965,16 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { // testFailingWrite sends a packet of the given test flow into the UDP endpoint // and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { +func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) { c.t.Helper() // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - payload := buffer.View(newPayload()) - _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + var r bytes.Reader + r.Reset(newPayload()) + _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) c.checkEndpointWriteStats(1, epstats, gotErr) @@ -1007,8 +1014,10 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, } } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("Write failed: %s", err) } @@ -1089,7 +1098,7 @@ func TestDualWriteConnectedToV6(t *testing.T) { testWrite(c, unicastV6) // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{}) const want = 1 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) @@ -1110,7 +1119,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { testWrite(c, unicastV4in6) // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) + testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) } func TestV4WriteOnV6Only(t *testing.T) { @@ -1120,7 +1129,7 @@ func TestV4WriteOnV6Only(t *testing.T) { c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) + testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{}) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { @@ -1135,7 +1144,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) + testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) } func TestV6WriteOnConnected(t *testing.T) { @@ -1183,8 +1192,10 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { writeOpts := tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) } @@ -1192,8 +1203,11 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) } - if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused { - c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + { + err := c.ep.LastError() + if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { + c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + } } }) } @@ -2303,21 +2317,21 @@ func TestShutdownWrite(t *testing.T) { t.Fatalf("Shutdown failed: %s", err) } - testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) + testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{}) } -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { + switch err.(type) { case nil: want.PacketsSent.IncrementBy(incr) - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: want.WriteErrors.InvalidArgs.IncrementBy(incr) - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: want.WriteErrors.WriteClosed.IncrementBy(incr) - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: want.SendErrors.NoRoute.IncrementBy(incr) default: want.SendErrors.SendToNetworkFailed.IncrementBy(incr) @@ -2327,11 +2341,11 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE } } -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil, tcpip.ErrWouldBlock: - case tcpip.ErrClosedForReceive: + switch err.(type) { + case nil, *tcpip.ErrWouldBlock: + case *tcpip.ErrClosedForReceive: want.ReadErrors.ReadClosed.IncrementBy(incr) default: c.t.Errorf("Endpoint error missing stats update err %v", err) @@ -2497,31 +2511,54 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } defer ep.Close() - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + var r bytes.Reader + data := []byte{1, 2, 3, 4} to := tcpip.FullAddress{ Addr: test.remoteAddr, Port: 80, } opts := tcpip.WriteOptions{To: &to} - expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { + if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { + return nil + } + return &tcpip.ErrBroadcastDisabled{} + } if !test.requiresBroadcastOpt { expectedErrWithoutBcastOpt = nil } - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) + r.Reset(data) + { + n, err := ep.Write(&r, opts) + if expectedErrWithoutBcastOpt != nil { + if want := expectedErrWithoutBcastOpt(err); want != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) + } + } else if err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) + } } ep.SocketOptions().SetBroadcast(true) - if n, err := ep.Write(data, opts); err != nil { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != nil { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) } ep.SocketOptions().SetBroadcast(false) - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) + r.Reset(data) + { + n, err := ep.Write(&r, opts) + if expectedErrWithoutBcastOpt != nil { + if want := expectedErrWithoutBcastOpt(err); want != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) + } + } else if err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) + } } }) } diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 79db8895b..dc2571154 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -517,28 +517,29 @@ func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, er // Reader returns an io.Reader that reads from s. Reads beyond the end of s // return io.EOF. The preconditions that apply to s.CopyIn also apply to the // returned io.Reader.Read. -func (s IOSequence) Reader(ctx context.Context) io.Reader { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Reader(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // Writer returns an io.Writer that writes to s. Writes beyond the end of s // return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also // apply to the returned io.Writer.Write. -func (s IOSequence) Writer(ctx context.Context) io.Writer { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Writer(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when // attempting to write beyond the end of the IOSequence. var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence") -type ioSequenceReadWriter struct { +// IOSequenceReadWriter implements io.Reader and io.Writer for an IOSequence. +type IOSequenceReadWriter struct { ctx context.Context s IOSequence } // Read implements io.Reader.Read. -func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { +func (rw *IOSequenceReadWriter) Read(dst []byte) (int, error) { n, err := rw.s.CopyIn(rw.ctx, dst) rw.s = rw.s.DropFirst(n) if err == nil && rw.s.NumBytes() == 0 { @@ -547,8 +548,13 @@ func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { return n, err } +// Len implements tcpip.Payloader. +func (rw *IOSequenceReadWriter) Len() int { + return int(rw.s.NumBytes()) +} + // Write implements io.Writer.Write. -func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) { +func (rw *IOSequenceReadWriter) Write(src []byte) (int, error) { n, err := rw.s.CopyOut(rw.ctx, src) rw.s = rw.s.DropFirst(n) if err == nil && n < len(src) { diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index a92ae046d..3bbf86534 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -52,7 +52,7 @@ func waitForProcessList(cont *Container, want []*control.Process) error { cb := func() error { got, err := cont.Processes() if err != nil { - err = fmt.Errorf("error getting process data from container: %v", err) + err = fmt.Errorf("error getting process data from container: %w", err) return &backoff.PermanentError{Err: err} } if !procListsEqual(got, want) { @@ -64,11 +64,30 @@ func waitForProcessList(cont *Container, want []*control.Process) error { return testutil.Poll(cb, 30*time.Second) } +// waitForProcess waits for the given process to show up in the container. +func waitForProcess(cont *Container, want *control.Process) error { + cb := func() error { + gots, err := cont.Processes() + if err != nil { + err = fmt.Errorf("error getting process data from container: %w", err) + return &backoff.PermanentError{Err: err} + } + for _, got := range gots { + if procEqual(got, want) { + return nil + } + } + return fmt.Errorf("container got process list: %s, want: %+v", procListToString(gots), want) + } + // Gives plenty of time as tests can run slow under --race. + return testutil.Poll(cb, 30*time.Second) +} + func waitForProcessCount(cont *Container, want int) error { cb := func() error { pss, err := cont.Processes() if err != nil { - err = fmt.Errorf("error getting process data from container: %v", err) + err = fmt.Errorf("error getting process data from container: %w", err) return &backoff.PermanentError{Err: err} } if got := len(pss); got != want { @@ -101,28 +120,32 @@ func procListsEqual(gots, wants []*control.Process) bool { return false } for i := range gots { - got := gots[i] - want := wants[i] - - if want.UID != math.MaxUint32 && want.UID != got.UID { - return false - } - if want.PID != -1 && want.PID != got.PID { - return false - } - if want.PPID != -1 && want.PPID != got.PPID { - return false - } - if len(want.TTY) != 0 && want.TTY != got.TTY { - return false - } - if len(want.Cmd) != 0 && want.Cmd != got.Cmd { + if !procEqual(gots[i], wants[i]) { return false } } return true } +func procEqual(got, want *control.Process) bool { + if want.UID != math.MaxUint32 && want.UID != got.UID { + return false + } + if want.PID != -1 && want.PID != got.PID { + return false + } + if want.PPID != -1 && want.PPID != got.PPID { + return false + } + if len(want.TTY) != 0 && want.TTY != got.TTY { + return false + } + if len(want.Cmd) != 0 && want.Cmd != got.Cmd { + return false + } + return true +} + type processBuilder struct { process control.Process } diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 044eec6fe..bc802e075 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -1708,12 +1708,9 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { t.Errorf("wait on child container: %v", err) } - // Wait for the root container to run. - expectedPL := []*control.Process{ - newProcessBuilder().Cmd("sh").Process(), - newProcessBuilder().Cmd("sleep").Process(), - } - if err := waitForProcessList(containers[0], expectedPL); err != nil { + // Wait until after `env` has executed. + expectedProc := newProcessBuilder().Cmd("sleep").Process() + if err := waitForProcess(containers[0], expectedProc); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) } @@ -1831,7 +1828,7 @@ func TestDuplicateEnvVariable(t *testing.T) { cmd1 := fmt.Sprintf("env > %q; sleep 1000", files[0].Name()) cmd2 := fmt.Sprintf("env > %q", files[1].Name()) cmdExec := fmt.Sprintf("env > %q", files[2].Name()) - testSpecs, ids := createSpecs([]string{"/bin/bash", "-c", cmd1}, []string{"/bin/bash", "-c", cmd2}) + testSpecs, ids := createSpecs([]string{"/bin/sh", "-c", cmd1}, []string{"/bin/sh", "-c", cmd2}) testSpecs[0].Process.Env = append(testSpecs[0].Process.Env, "VAR=foo", "VAR=bar") testSpecs[1].Process.Env = append(testSpecs[1].Process.Env, "VAR=foo", "VAR=bar") @@ -1841,12 +1838,9 @@ func TestDuplicateEnvVariable(t *testing.T) { } defer cleanup() - // Wait for the `env` from the root container to finish. - expectedPL := []*control.Process{ - newProcessBuilder().Cmd("bash").Process(), - newProcessBuilder().Cmd("sleep").Process(), - } - if err := waitForProcessList(containers[0], expectedPL); err != nil { + // Wait until after `env` has executed. + expectedProc := newProcessBuilder().Cmd("sleep").Process() + if err := waitForProcess(containers[0], expectedProc); err != nil { t.Errorf("failed to wait for sleep to start: %v", err) } if ws, err := containers[1].Wait(); err != nil { @@ -1856,8 +1850,8 @@ func TestDuplicateEnvVariable(t *testing.T) { } execArgs := &control.ExecArgs{ - Filename: "/bin/bash", - Argv: []string{"/bin/bash", "-c", cmdExec}, + Filename: "/bin/sh", + Argv: []string{"/bin/sh", "-c", cmdExec}, Envv: []string{"VAR=foo", "VAR=bar"}, } if ws, err := containers[0].executeSync(execArgs); err != nil || ws.ExitStatus() != 0 { diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index c56e1d4d0..3280b74fe 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -12,7 +12,6 @@ go_library( ], visibility = ["//runsc:__subpackages__"], deps = [ - "//pkg/abi/linux", "//pkg/cleanup", "//pkg/fd", "//pkg/log", diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index c3bba0973..cfa3796b1 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -31,7 +31,6 @@ import ( "strconv" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" @@ -49,13 +48,6 @@ const ( allowedOpenFlags = unix.O_TRUNC ) -var ( - // Remember the process uid/gid to skip chown calls when file owner/group - // doesn't need to be changed. - processUID = p9.UID(os.Getuid()) - processGID = p9.GID(os.Getgid()) -) - // join is equivalent to path.Join() but skips path.Clean() which is expensive. func join(parent, child string) string { if child == "." || child == ".." { @@ -374,7 +366,24 @@ func fstat(fd int) (unix.Stat_t, error) { } func fchown(fd int, uid p9.UID, gid p9.GID) error { - return unix.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW) + return unix.Fchownat(fd, "", int(uid), int(gid), unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW) +} + +func setOwnerIfNeeded(fd int, uid p9.UID, gid p9.GID) (unix.Stat_t, error) { + stat, err := fstat(fd) + if err != nil { + return unix.Stat_t{}, err + } + + // Change ownership if not set accordinly. + if uint32(uid) != stat.Uid || uint32(gid) != stat.Gid { + if err := fchown(fd, uid, gid); err != nil { + return unix.Stat_t{}, err + } + stat.Uid = uint32(uid) + stat.Gid = uint32(gid) + } + return stat, nil } // Open implements p9.File. @@ -457,12 +466,7 @@ func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode, }) defer cu.Clean() - if uid != processUID || gid != processGID { - if err := fchown(child.FD(), uid, gid); err != nil { - return nil, nil, p9.QID{}, 0, extractErrno(err) - } - } - stat, err := fstat(child.FD()) + stat, err := setOwnerIfNeeded(child.FD(), uid, gid) if err != nil { return nil, nil, p9.QID{}, 0, extractErrno(err) } @@ -505,12 +509,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) } defer f.Close() - if uid != processUID || gid != processGID { - if err := fchown(f.FD(), uid, gid); err != nil { - return p9.QID{}, extractErrno(err) - } - } - stat, err := fstat(f.FD()) + stat, err := setOwnerIfNeeded(f.FD(), uid, gid) if err != nil { return p9.QID{}, extractErrno(err) } @@ -734,15 +733,15 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { if valid.ATime || valid.MTime { utimes := [2]unix.Timespec{ - {Sec: 0, Nsec: linux.UTIME_OMIT}, - {Sec: 0, Nsec: linux.UTIME_OMIT}, + {Sec: 0, Nsec: unix.UTIME_OMIT}, + {Sec: 0, Nsec: unix.UTIME_OMIT}, } if valid.ATime { if valid.ATimeNotSystemTime { utimes[0].Sec = int64(attr.ATimeSeconds) utimes[0].Nsec = int64(attr.ATimeNanoSeconds) } else { - utimes[0].Nsec = linux.UTIME_NOW + utimes[0].Nsec = unix.UTIME_NOW } } if valid.MTime { @@ -750,7 +749,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { utimes[1].Sec = int64(attr.MTimeSeconds) utimes[1].Nsec = int64(attr.MTimeNanoSeconds) } else { - utimes[1].Nsec = linux.UTIME_NOW + utimes[1].Nsec = unix.UTIME_NOW } } @@ -764,7 +763,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } defer unix.Close(parent) - if tErr := utimensat(parent, path.Base(l.hostPath), utimes, linux.AT_SYMLINK_NOFOLLOW); tErr != nil { + if tErr := utimensat(parent, path.Base(l.hostPath), utimes, unix.AT_SYMLINK_NOFOLLOW); tErr != nil { log.Debugf("SetAttr utimens failed %q, err: %v", l.hostPath, tErr) err = extractErrno(tErr) } @@ -779,15 +778,15 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } if valid.UID || valid.GID { - uid := -1 + uid := p9.NoUID if valid.UID { - uid = int(attr.UID) + uid = attr.UID } - gid := -1 + gid := p9.NoGID if valid.GID { - gid = int(attr.GID) + gid = attr.GID } - if oErr := unix.Fchownat(f.FD(), "", uid, gid, linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW); oErr != nil { + if oErr := fchown(f.FD(), uid, gid); oErr != nil { log.Debugf("SetAttr fchownat failed %q, err: %v", l.hostPath, oErr) err = extractErrno(oErr) } @@ -900,12 +899,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. } defer f.Close() - if uid != processUID || gid != processGID { - if err := fchown(f.FD(), uid, gid); err != nil { - return p9.QID{}, extractErrno(err) - } - } - stat, err := fstat(f.FD()) + stat, err := setOwnerIfNeeded(f.FD(), uid, gid) if err != nil { return p9.QID{}, extractErrno(err) } @@ -921,7 +915,7 @@ func (l *localFile) Link(target p9.File, newName string) error { } targetFile := target.(*localFile) - if err := unix.Linkat(targetFile.file.FD(), "", l.file.FD(), newName, linux.AT_EMPTY_PATH); err != nil { + if err := unix.Linkat(targetFile.file.FD(), "", l.file.FD(), newName, unix.AT_EMPTY_PATH); err != nil { return extractErrno(err) } return nil @@ -959,12 +953,7 @@ func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid } defer child.Close() - if uid != processUID || gid != processGID { - if err := fchown(child.FD(), uid, gid); err != nil { - return p9.QID{}, extractErrno(err) - } - } - stat, err := fstat(child.FD()) + stat, err := setOwnerIfNeeded(child.FD(), uid, gid) if err != nil { return p9.QID{}, extractErrno(err) } @@ -1113,7 +1102,8 @@ func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { // mappings, the app path may have fit in the sockaddr, but we can't // fit f.path in our sockaddr. We'd need to redirect through a shorter // path in order to actually connect to this socket. - if len(l.hostPath) > linux.UnixPathMax { + const UNIX_PATH_MAX = 108 // defined in afunix.h + if len(l.hostPath) > UNIX_PATH_MAX { return nil, unix.ECONNREFUSED } diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go index c46958185..29ebf8500 100644 --- a/runsc/fsgofer/fsgofer_amd64_unsafe.go +++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go @@ -20,7 +20,6 @@ import ( "unsafe" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/syserr" ) @@ -39,7 +38,7 @@ func statAt(dirFd int, name string) (unix.Stat_t, error) { uintptr(dirFd), uintptr(namePtr), uintptr(statPtr), - linux.AT_SYMLINK_NOFOLLOW, + unix.AT_SYMLINK_NOFOLLOW, 0, 0); errno != 0 { diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go index 491460718..9fd5d0871 100644 --- a/runsc/fsgofer/fsgofer_arm64_unsafe.go +++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go @@ -20,7 +20,6 @@ import ( "unsafe" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/syserr" ) @@ -39,7 +38,7 @@ func statAt(dirFd int, name string) (unix.Stat_t, error) { uintptr(dirFd), uintptr(namePtr), uintptr(statPtr), - linux.AT_SYMLINK_NOFOLLOW, + unix.AT_SYMLINK_NOFOLLOW, 0, 0); errno != 0 { diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index c5daebe5e..99ea9bd32 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -32,6 +32,9 @@ import ( "gvisor.dev/gvisor/runsc/specutils" ) +// Nodoby is the standard UID/GID for the nobody user/group. +const nobody = 65534 + var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} var ( @@ -281,6 +284,92 @@ func TestCreate(t *testing.T) { }) } +func checkIDs(f p9.File, uid, gid int) error { + _, _, stat, err := f.GetAttr(p9.AttrMask{UID: true, GID: true}) + if err != nil { + return fmt.Errorf("GetAttr() failed, err: %v", err) + } + if want := p9.UID(uid); stat.UID != want { + return fmt.Errorf("Wrong UID, want: %v, got: %v", want, stat.UID) + } + if want := p9.GID(gid); stat.GID != want { + return fmt.Errorf("Wrong GID, want: %v, got: %v", want, stat.GID) + } + return nil +} + +// TestCreateSetGID checks files/dirs/symlinks are created with the proper +// owner when the parent directory has setgid set, +func TestCreateSetGID(t *testing.T) { + if !specutils.HasCapabilities(capability.CAP_CHOWN) { + t.Skipf("Test requires CAP_CHOWN") + } + + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + // Change group and set setgid to the parent dir. + if err := unix.Chown(s.file.hostPath, os.Getuid(), nobody); err != nil { + t.Fatalf("Chown() failed: %v", err) + } + if err := unix.Chmod(s.file.hostPath, 02777); err != nil { + t.Fatalf("Chmod() failed: %v", err) + } + + t.Run("create", func(t *testing.T) { + _, l, _, _, err := s.file.Create("test", p9.ReadOnly, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) + if err != nil { + t.Fatalf("WriteAt() failed: %v", err) + } + defer l.Close() + if err := checkIDs(l, os.Getuid(), os.Getgid()); err != nil { + t.Error(err) + } + }) + + t.Run("mkdir", func(t *testing.T) { + _, err := s.file.Mkdir("test-dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) + if err != nil { + t.Fatalf("WriteAt() failed: %v", err) + } + _, l, err := s.file.Walk([]string{"test-dir"}) + if err != nil { + t.Fatalf("Walk() failed: %v", err) + } + defer l.Close() + if err := checkIDs(l, os.Getuid(), os.Getgid()); err != nil { + t.Error(err) + } + }) + + t.Run("symlink", func(t *testing.T) { + if _, err := s.file.Symlink("/some/target", "symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { + t.Fatalf("Symlink() failed: %v", err) + } + _, l, err := s.file.Walk([]string{"symlink"}) + if err != nil { + t.Fatalf("Walk() failed, err: %v", err) + } + defer l.Close() + if err := checkIDs(l, os.Getuid(), os.Getgid()); err != nil { + t.Error(err) + } + }) + + t.Run("mknod", func(t *testing.T) { + if _, err := s.file.Mknod("nod", p9.ModeRegular|0777, 0, 0, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { + t.Fatalf("Mknod() failed: %v", err) + } + _, l, err := s.file.Walk([]string{"nod"}) + if err != nil { + t.Fatalf("Walk() failed, err: %v", err) + } + defer l.Close() + if err := checkIDs(l, os.Getuid(), os.Getgid()); err != nil { + t.Error(err) + } + }) + }) +} + // TestReadWriteDup tests that a file opened in any mode can be dup'ed and // reopened in any other mode. func TestReadWriteDup(t *testing.T) { @@ -458,7 +547,7 @@ func TestSetAttrTime(t *testing.T) { } func TestSetAttrOwner(t *testing.T) { - if os.Getuid() != 0 { + if !specutils.HasCapabilities(capability.CAP_CHOWN) { t.Skipf("SetAttr(owner) test requires CAP_CHOWN, running as %d", os.Getuid()) } @@ -477,7 +566,7 @@ func TestSetAttrOwner(t *testing.T) { } func TestLink(t *testing.T) { - if os.Getuid() != 0 { + if !specutils.HasCapabilities(capability.CAP_DAC_READ_SEARCH) { t.Skipf("Link test requires CAP_DAC_READ_SEARCH, running as %d", os.Getuid()) } runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) { @@ -995,7 +1084,6 @@ func BenchmarkCreateDiffOwner(b *testing.B) { files := make([]p9.File, 0, 500) fds := make([]*fd.FD, 0, 500) gid := p9.GID(os.Getgid()) - const nobody = 65534 b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/runsc/mitigate/BUILD b/runsc/mitigate/BUILD new file mode 100644 index 000000000..3b0342d18 --- /dev/null +++ b/runsc/mitigate/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "mitigate", + srcs = [ + "cpu.go", + "mitigate.go", + ], + deps = ["@in_gopkg_yaml_v2//:go_default_library"], +) + +go_test( + name = "mitigate_test", + size = "small", + srcs = ["cpu_test.go"], + library = ":mitigate", + deps = ["@com_github_google_go_cmp//cmp:go_default_library"], +) diff --git a/runsc/mitigate/cpu.go b/runsc/mitigate/cpu.go new file mode 100644 index 000000000..113b98159 --- /dev/null +++ b/runsc/mitigate/cpu.go @@ -0,0 +1,235 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mitigate + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +const ( + // constants of coomm + meltdown = "cpu_meltdown" + l1tf = "l1tf" + mds = "mds" + swapgs = "swapgs" + taa = "taa" +) + +const ( + processorKey = "processor" + vendorIDKey = "vendor_id" + cpuFamilyKey = "cpu family" + modelKey = "model" + coreIDKey = "core id" + bugsKey = "bugs" +) + +// getCPUSet returns cpu structs from reading /proc/cpuinfo. +func getCPUSet(data string) ([]*cpu, error) { + // Each processor entry should start with the + // processor key. Find the beginings of each. + r := buildRegex(processorKey, `\d+`) + indices := r.FindAllStringIndex(data, -1) + if len(indices) < 1 { + return nil, fmt.Errorf("no cpus found for: %s", data) + } + + // Add the ending index for last entry. + indices = append(indices, []int{len(data), -1}) + + // Valid cpus are now defined by strings in between + // indexes (e.g. data[index[i], index[i+1]]). + // There should be len(indicies) - 1 CPUs + // since the last index is the end of the string. + var cpus = make([]*cpu, 0, len(indices)-1) + // Find each string that represents a CPU. These begin "processor". + for i := 1; i < len(indices); i++ { + start := indices[i-1][0] + end := indices[i][0] + // Parse the CPU entry, which should be between start/end. + c, err := getCPU(data[start:end]) + if err != nil { + return nil, err + } + cpus = append(cpus, c) + } + return cpus, nil +} + +// type cpu represents pertinent info about a cpu. +type cpu struct { + processorNumber int64 // the processor number of this CPU. + vendorID string // the vendorID of CPU (e.g. AuthenticAMD). + cpuFamily int64 // CPU family number (e.g. 6 for CascadeLake/Skylake). + model int64 // CPU model number (e.g. 85 for CascadeLake/Skylake). + coreID int64 // This CPU's core id to match Hyperthread Pairs + bugs map[string]struct{} // map of vulnerabilities parsed from the 'bugs' field. +} + +// getCPU parses a CPU from a single cpu entry from /proc/cpuinfo. +func getCPU(data string) (*cpu, error) { + processor, err := parseProcessor(data) + if err != nil { + return nil, err + } + + vendorID, err := parseVendorID(data) + if err != nil { + return nil, err + } + + cpuFamily, err := parseCPUFamily(data) + if err != nil { + return nil, err + } + + model, err := parseModel(data) + if err != nil { + return nil, err + } + + coreID, err := parseCoreID(data) + if err != nil { + return nil, err + } + + bugs, err := parseBugs(data) + if err != nil { + return nil, err + } + + return &cpu{ + processorNumber: processor, + vendorID: vendorID, + cpuFamily: cpuFamily, + model: model, + coreID: coreID, + bugs: bugs, + }, nil +} + +// List of pertinent side channel vulnerablilites. +// For mds, see: https://www.kernel.org/doc/html/latest/admin-guide/hw-vuln/mds.html. +var vulnerabilities = []string{ + meltdown, + l1tf, + mds, + swapgs, + taa, +} + +// isVulnerable checks if a CPU is vulnerable to pertinent bugs. +func (c *cpu) isVulnerable() bool { + for _, bug := range vulnerabilities { + if _, ok := c.bugs[bug]; ok { + return true + } + } + return false +} + +// similarTo checks family/model/bugs fields for equality of two +// processors. +func (c *cpu) similarTo(other *cpu) bool { + if c.vendorID != other.vendorID { + return false + } + + if other.cpuFamily != c.cpuFamily { + return false + } + + if other.model != c.model { + return false + } + + if len(other.bugs) != len(c.bugs) { + return false + } + + for bug := range c.bugs { + if _, ok := other.bugs[bug]; !ok { + return false + } + } + return true +} + +// parseProcessor grabs the processor field from /proc/cpuinfo output. +func parseProcessor(data string) (int64, error) { + return parseIntegerResult(data, processorKey) +} + +// parseVendorID grabs the vendor_id field from /proc/cpuinfo output. +func parseVendorID(data string) (string, error) { + return parseRegex(data, vendorIDKey, `[\w\d]+`) +} + +// parseCPUFamily grabs the cpu family field from /proc/cpuinfo output. +func parseCPUFamily(data string) (int64, error) { + return parseIntegerResult(data, cpuFamilyKey) +} + +// parseModel grabs the model field from /proc/cpuinfo output. +func parseModel(data string) (int64, error) { + return parseIntegerResult(data, modelKey) +} + +// parseCoreID parses the core id field. +func parseCoreID(data string) (int64, error) { + return parseIntegerResult(data, coreIDKey) +} + +// parseBugs grabs the bugs field from /proc/cpuinfo output. +func parseBugs(data string) (map[string]struct{}, error) { + result, err := parseRegex(data, bugsKey, `[\d\w\s]*`) + if err != nil { + return nil, err + } + bugs := strings.Split(result, " ") + ret := make(map[string]struct{}, len(bugs)) + for _, bug := range bugs { + ret[bug] = struct{}{} + } + return ret, nil +} + +// parseIntegerResult parses fields expecting an integer. +func parseIntegerResult(data, key string) (int64, error) { + result, err := parseRegex(data, key, `\d+`) + if err != nil { + return 0, err + } + return strconv.ParseInt(result, 0, 64) +} + +// buildRegex builds a regex for parsing each CPU field. +func buildRegex(key, match string) *regexp.Regexp { + reg := fmt.Sprintf(`(?m)^%s\s*:\s*(.*)$`, key) + return regexp.MustCompile(reg) +} + +// parseRegex parses data with key inserted into a standard regex template. +func parseRegex(data, key, match string) (string, error) { + r := buildRegex(key, match) + matches := r.FindStringSubmatch(data) + if len(matches) < 2 { + return "", fmt.Errorf("failed to match key %s: %s", key, data) + } + return matches[1], nil +} diff --git a/runsc/mitigate/cpu_test.go b/runsc/mitigate/cpu_test.go new file mode 100644 index 000000000..77b714a02 --- /dev/null +++ b/runsc/mitigate/cpu_test.go @@ -0,0 +1,368 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mitigate + +import ( + "io/ioutil" + "strings" + "testing" +) + +// CPU info for a Intel CascadeLake processor. Both Skylake and CascadeLake have +// the same family/model numbers, but with different bugs (e.g. skylake has +// cpu_meltdown). +var cascadeLake = &cpu{ + vendorID: "GenuineIntel", + cpuFamily: 6, + model: 85, + bugs: map[string]struct{}{ + "spectre_v1": struct{}{}, + "spectre_v2": struct{}{}, + "spec_store_bypass": struct{}{}, + mds: struct{}{}, + swapgs: struct{}{}, + taa: struct{}{}, + }, +} + +// TestGetCPU tests basic parsing of single CPU strings from reading +// /proc/cpuinfo. +func TestGetCPU(t *testing.T) { + data := `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 85 +core id : 0 +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa itlb_multihit +` + want := cpu{ + processorNumber: 0, + vendorID: "GenuineIntel", + cpuFamily: 6, + model: 85, + coreID: 0, + bugs: map[string]struct{}{ + "cpu_meltdown": struct{}{}, + "spectre_v1": struct{}{}, + "spectre_v2": struct{}{}, + "spec_store_bypass": struct{}{}, + "l1tf": struct{}{}, + "mds": struct{}{}, + "swapgs": struct{}{}, + "taa": struct{}{}, + "itlb_multihit": struct{}{}, + }, + } + + got, err := getCPU(data) + if err != nil { + t.Fatalf("getCpu failed with error: %v", err) + } + + if !want.similarTo(got) { + t.Fatalf("Failed cpus not similar: got: %+v, want: %+v", got, want) + } + + if !got.isVulnerable() { + t.Fatalf("Failed: cpu should be vulnerable.") + } +} + +func TestInvalid(t *testing.T) { + result, err := getCPUSet(`something not a processor`) + if err == nil { + t.Fatalf("getCPU set didn't return an error: %+v", result) + } + + if !strings.Contains(err.Error(), "no cpus") { + t.Fatalf("Incorrect error returned: %v", err) + } +} + +// TestCPUSet tests getting the right number of CPUs from +// parsing full output of /proc/cpuinfo. +func TestCPUSet(t *testing.T) { + data := `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 63 +model name : Intel(R) Xeon(R) CPU @ 2.30GHz +stepping : 0 +microcode : 0x1 +cpu MHz : 2299.998 +cache size : 46080 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs +bogomips : 4599.99 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management: + +processor : 1 +vendor_id : GenuineIntel +cpu family : 6 +model : 63 +model name : Intel(R) Xeon(R) CPU @ 2.30GHz +stepping : 0 +microcode : 0x1 +cpu MHz : 2299.998 +cache size : 46080 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 1 +initial apicid : 1 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs +bogomips : 4599.99 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management: +` + cpuSet, err := getCPUSet(data) + if err != nil { + t.Fatalf("getCPUSet failed: %v", err) + } + + wantCPULen := 2 + if len(cpuSet) != wantCPULen { + t.Fatalf("Num CPU mismatch: want: %d, got: %d", wantCPULen, len(cpuSet)) + } + + wantCPU := cpu{ + vendorID: "GenuineIntel", + cpuFamily: 6, + model: 63, + bugs: map[string]struct{}{ + "cpu_meltdown": struct{}{}, + "spectre_v1": struct{}{}, + "spectre_v2": struct{}{}, + "spec_store_bypass": struct{}{}, + "l1tf": struct{}{}, + "mds": struct{}{}, + "swapgs": struct{}{}, + }, + } + + for _, c := range cpuSet { + if !wantCPU.similarTo(c) { + t.Fatalf("Failed cpus not equal: got: %+v, want: %+v", c, wantCPU) + } + } +} + +// TestReadFile is a smoke test for parsing methods. +func TestReadFile(t *testing.T) { + data, err := ioutil.ReadFile("/proc/cpuinfo") + if err != nil { + t.Fatalf("Failed to read cpuinfo: %v", err) + } + + set, err := getCPUSet(string(data)) + if err != nil { + t.Fatalf("Failed to parse CPU data %v\n%s", err, data) + } + + if len(set) < 1 { + t.Fatalf("Failed to parse any CPUs: %d", len(set)) + } + + for _, c := range set { + t.Logf("CPU: %+v: %t", c, c.isVulnerable()) + } +} + +// TestVulnerable tests if the isVulnerable method is correct +// among known CPUs in GCP. +func TestVulnerable(t *testing.T) { + const haswell = `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 63 +model name : Intel(R) Xeon(R) CPU @ 2.30GHz +stepping : 0 +microcode : 0x1 +cpu MHz : 2299.998 +cache size : 46080 KB +physical id : 0 +siblings : 4 +core id : 0 +cpu cores : 2 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs +bogomips : 4599.99 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management:` + + const skylake = `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 85 +model name : Intel(R) Xeon(R) CPU @ 2.00GHz +stepping : 3 +microcode : 0x1 +cpu MHz : 2000.180 +cache size : 39424 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa +bogomips : 4000.36 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management:` + + const cascade = `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 85 +model name : Intel(R) Xeon(R) CPU +stepping : 7 +microcode : 0x1 +cpu MHz : 2800.198 +cache size : 33792 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 + ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmu +lqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowpr +efetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid r +tm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves a +rat avx512_vnni md_clear arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa +bogomips : 5600.39 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management:` + + const amd = `processor : 0 +vendor_id : AuthenticAMD +cpu family : 23 +model : 49 +model name : AMD EPYC 7B12 +stepping : 0 +microcode : 0x1000065 +cpu MHz : 2250.000 +cache size : 512 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass +bogomips : 4500.00 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management:` + + for _, tc := range []struct { + name string + cpuString string + vulnerable bool + }{ + { + name: "haswell", + cpuString: haswell, + vulnerable: true, + }, { + name: "skylake", + cpuString: skylake, + vulnerable: true, + }, { + name: "cascadeLake", + cpuString: cascade, + vulnerable: false, + }, { + name: "amd", + cpuString: amd, + vulnerable: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + set, err := getCPUSet(tc.cpuString) + if err != nil { + t.Fatalf("Failed to getCPUSet:%v\n %s", err, tc.cpuString) + } + + if len(set) < 1 { + t.Fatalf("Returned empty cpu set: %v", set) + } + + for _, c := range set { + got := func() bool { + if cascadeLake.similarTo(c) { + return false + } + return c.isVulnerable() + }() + + if got != tc.vulnerable { + t.Fatalf("Mismatch vulnerable for cpu %+s: got %t want: %t", tc.name, tc.vulnerable, got) + } + } + }) + } +} diff --git a/runsc/mitigate/mitigate.go b/runsc/mitigate/mitigate.go new file mode 100644 index 000000000..51d5449b6 --- /dev/null +++ b/runsc/mitigate/mitigate.go @@ -0,0 +1,20 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mitigate provides libraries for the mitigate command. The +// mitigate command mitigates side channel attacks such as MDS. Mitigate +// shuts down CPUs via /sys/devices/system/cpu/cpu{N}/online. In addition, +// the mitigate also handles computing available CPU in kubernetes kube_config +// files. +package mitigate diff --git a/test/fsstress/fsstress_test.go b/test/fsstress/fsstress_test.go index 63f692ca9..300c21ceb 100644 --- a/test/fsstress/fsstress_test.go +++ b/test/fsstress/fsstress_test.go @@ -41,7 +41,7 @@ func fsstress(t *testing.T, dir string) { image = "basic/fsstress" ) seed := strconv.FormatUint(uint64(rand.Uint32()), 10) - args := []string{"-d", dir, "-n", operations, "-p", processes, "-seed", seed, "-X"} + args := []string{"-d", dir, "-n", operations, "-p", processes, "-s", seed, "-X"} t.Logf("Repro: docker run --rm --runtime=runsc %s %s", image, strings.Join(args, "")) out, err := d.Run(ctx, dockerutil.RunOpts{Image: image}, args...) if err != nil { diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 66453772a..ae4bba847 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -15,7 +15,9 @@ go_library( ], visibility = ["//test/iptables:__subpackages__"], deps = [ + "//pkg/binary", "//pkg/test/testutil", + "//pkg/usermem", ], ) diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 37a1a6694..c47660026 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -51,6 +51,12 @@ func init() { RegisterTestCase(FilterInputInvertDestination{}) RegisterTestCase(FilterInputSource{}) RegisterTestCase(FilterInputInvertSource{}) + RegisterTestCase(FilterInputInterfaceAccept{}) + RegisterTestCase(FilterInputInterfaceDrop{}) + RegisterTestCase(FilterInputInterface{}) + RegisterTestCase(FilterInputInterfaceBeginsWith{}) + RegisterTestCase(FilterInputInterfaceInvertDrop{}) + RegisterTestCase(FilterInputInterfaceInvertAccept{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -744,3 +750,195 @@ func (FilterInputInvertSource) ContainerAction(ctx context.Context, ip net.IP, i func (FilterInputInvertSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { return sendUDPLoop(ctx, ip, acceptPort) } + +// FilterInputInterfaceAccept tests that packets are accepted from interface +// matching the iptables rule. +type FilterInputInterfaceAccept struct{ localCase } + +var _ TestCase = FilterInputInterfaceAccept{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceAccept) Name() string { + return "FilterInputInterfaceAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", ifname, "-j", "ACCEPT"); err != nil { + return err + } + if err := listenUDP(ctx, acceptPort); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %w", acceptPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInterfaceDrop tests that packets are dropped from interface +// matching the iptables rule. +type FilterInputInterfaceDrop struct{ localCase } + +var _ TestCase = FilterInputInterfaceDrop{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceDrop) Name() string { + return "FilterInputInterfaceDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", ifname, "-j", "DROP"); err != nil { + return err + } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil + } + return fmt.Errorf("error reading: %w", err) + } + return fmt.Errorf("packets should have been dropped, but got a packet") +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInterface tests that packets are not dropped from interface which +// is not matching the interface name in the iptables rule. +type FilterInputInterface struct{ localCase } + +var _ TestCase = FilterInputInterface{} + +// Name implements TestCase.Name. +func (FilterInputInterface) Name() string { + return "FilterInputInterface" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterface) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", "lo", "-j", "DROP"); err != nil { + return err + } + if err := listenUDP(ctx, acceptPort); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %w", acceptPort, err) + } + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterface) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInterfaceBeginsWith tests that packets are dropped from an +// interface which begins with the given interface name. +type FilterInputInterfaceBeginsWith struct{ localCase } + +var _ TestCase = FilterInputInterfaceBeginsWith{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceBeginsWith) Name() string { + return "FilterInputInterfaceBeginsWith" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceBeginsWith) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-i", "e+", "-j", "DROP"); err != nil { + return err + } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil + } + return fmt.Errorf("error reading: %w", err) + } + return fmt.Errorf("packets should have been dropped, but got a packet") +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceBeginsWith) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInterfaceInvertDrop tests that we selectively drop packets from +// interface not matching the interface name. +type FilterInputInterfaceInvertDrop struct{ baseCase } + +var _ TestCase = FilterInputInterfaceInvertDrop{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceInvertDrop) Name() string { + return "FilterInputInterfaceInvertDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceInvertDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "!", "-i", "lo", "-j", "DROP"); err != nil { + return err + } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil + } + return fmt.Errorf("error reading: %w", err) + } + return fmt.Errorf("connection on port %d should not be accepted, but was accepted", acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceInvertDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err != nil { + var operr *net.OpError + if errors.As(err, &operr) && operr.Timeout() { + return nil + } + return fmt.Errorf("error connecting: %w", err) + } + return fmt.Errorf("connection destined to port %d should not be accepted, but was accepted", acceptPort) +} + +// FilterInputInterfaceInvertAccept tests that we can selectively accept packets +// not matching the specific incoming interface. +type FilterInputInterfaceInvertAccept struct{ baseCase } + +var _ TestCase = FilterInputInterfaceInvertAccept{} + +// Name implements TestCase.Name. +func (FilterInputInterfaceInvertAccept) Name() string { + return "FilterInputInterfaceInvertAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInterfaceInvertAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "!", "-i", "lo", "-j", "ACCEPT"); err != nil { + return err + } + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInterfaceInvertAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 4733146c0..ef92e3fff 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -392,6 +392,30 @@ func TestInputInvertSource(t *testing.T) { singleTest(t, FilterInputInvertSource{}) } +func TestInputInterfaceAccept(t *testing.T) { + singleTest(t, FilterInputInterfaceAccept{}) +} + +func TestInputInterfaceDrop(t *testing.T) { + singleTest(t, FilterInputInterfaceDrop{}) +} + +func TestInputInterface(t *testing.T) { + singleTest(t, FilterInputInterface{}) +} + +func TestInputInterfaceBeginsWith(t *testing.T) { + singleTest(t, FilterInputInterfaceBeginsWith{}) +} + +func TestInputInterfaceInvertDrop(t *testing.T) { + singleTest(t, FilterInputInterfaceInvertDrop{}) +} + +func TestInputInterfaceInvertAccept(t *testing.T) { + singleTest(t, FilterInputInterfaceInvertAccept{}) +} + func TestFilterAddrs(t *testing.T) { tcs := []struct { ipv6 bool @@ -424,3 +448,11 @@ func TestNATPreOriginalDst(t *testing.T) { func TestNATOutOriginalDst(t *testing.T) { singleTest(t, NATOutOriginalDst{}) } + +func TestNATPreRECVORIGDSTADDR(t *testing.T) { + singleTest(t, NATPreRECVORIGDSTADDR{}) +} + +func TestNATOutRECVORIGDSTADDR(t *testing.T) { + singleTest(t, NATOutRECVORIGDSTADDR{}) +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index a6ec5cca3..4cd770a65 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -171,7 +171,7 @@ func connectTCP(ctx context.Context, ip net.IP, port int) error { return err } if err := testutil.PollContext(ctx, callback); err != nil { - return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) + return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %w", port, err) } return nil diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 495241482..c3874240f 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -20,6 +20,9 @@ import ( "fmt" "net" "syscall" + + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/usermem" ) const redirectPort = 42 @@ -43,6 +46,8 @@ func init() { RegisterTestCase(NATLoopbackSkipsPrerouting{}) RegisterTestCase(NATPreOriginalDst{}) RegisterTestCase(NATOutOriginalDst{}) + RegisterTestCase(NATPreRECVORIGDSTADDR{}) + RegisterTestCase(NATOutRECVORIGDSTADDR{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. @@ -538,9 +543,9 @@ func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) } func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error { - // The net package doesn't give guarantee access to the connection's + // The net package doesn't give guaranteed access to the connection's // underlying FD, and thus we cannot call getsockopt. We have to use - // traditional syscalls for SO_ORIGINAL_DST. + // traditional syscalls. // Create the listening socket, bind, listen, and accept. family := syscall.AF_INET @@ -609,36 +614,14 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net. if err != nil { return err } - // The original destination could be any of our IPs. - for _, dst := range originalDsts { - want := syscall.RawSockaddrInet6{ - Family: syscall.AF_INET6, - Port: htons(dropPort), - } - copy(want.Addr[:], dst.To16()) - if got == want { - return nil - } - } - return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + return addrMatches6(got, originalDsts, dropPort) } got, err := originalDestination4(connFD) if err != nil { return err } - // The original destination could be any of our IPs. - for _, dst := range originalDsts { - want := syscall.RawSockaddrInet4{ - Family: syscall.AF_INET, - Port: htons(dropPort), - } - copy(want.Addr[:], dst.To4()) - if got == want { - return nil - } - } - return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + return addrMatches4(got, originalDsts, dropPort) } // loopbackTests runs an iptables rule and ensures that packets sent to @@ -662,3 +645,233 @@ func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) e return err } } + +// NATPreRECVORIGDSTADDR tests that IP{V6}_RECVORIGDSTADDR gets the post-NAT +// address on the PREROUTING chain. +type NATPreRECVORIGDSTADDR struct{ containerCase } + +// Name implements TestCase.Name. +func (NATPreRECVORIGDSTADDR) Name() string { + return "NATPreRECVORIGDSTADDR" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRECVORIGDSTADDR) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + return err + } + + if err := recvWithRECVORIGDSTADDR(ctx, ipv6, nil, redirectPort); err != nil { + return err + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRECVORIGDSTADDR) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// NATOutRECVORIGDSTADDR tests that IP{V6}_RECVORIGDSTADDR gets the post-NAT +// address on the OUTPUT chain. +type NATOutRECVORIGDSTADDR struct{ containerCase } + +// Name implements TestCase.Name. +func (NATOutRECVORIGDSTADDR) Name() string { + return "NATOutRECVORIGDSTADDR" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRECVORIGDSTADDR) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + return err + } + + sendCh := make(chan error) + go func() { + // Packets will be sent to a non-container IP and redirected + // back to the container. + sendCh <- sendUDPLoop(ctx, ip, acceptPort) + }() + + expectedIP := &net.IP{127, 0, 0, 1} + if ipv6 { + expectedIP = &net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + } + if err := recvWithRECVORIGDSTADDR(ctx, ipv6, expectedIP, redirectPort); err != nil { + return err + } + + select { + case err := <-sendCh: + return err + default: + return nil + } +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRECVORIGDSTADDR) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +func recvWithRECVORIGDSTADDR(ctx context.Context, ipv6 bool, expectedDst *net.IP, port uint16) error { + // The net package doesn't give guaranteed access to a connection's + // underlying FD, and thus we cannot call getsockopt. We have to use + // traditional syscalls for IP_RECVORIGDSTADDR. + + // Create the listening socket. + var ( + family = syscall.AF_INET + level = syscall.SOL_IP + option = syscall.IP_RECVORIGDSTADDR + bindAddr syscall.Sockaddr = &syscall.SockaddrInet4{ + Port: int(port), + Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY + } + ) + if ipv6 { + family = syscall.AF_INET6 + level = syscall.SOL_IPV6 + option = 74 // IPV6_RECVORIGDSTADDR, which is missing from the syscall package. + bindAddr = &syscall.SockaddrInet6{ + Port: int(port), + Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any + } + } + sockfd, err := syscall.Socket(family, syscall.SOCK_DGRAM, 0) + if err != nil { + return fmt.Errorf("failed Socket(%d, %d, 0): %w", family, syscall.SOCK_DGRAM, err) + } + defer syscall.Close(sockfd) + + if err := syscall.Bind(sockfd, bindAddr); err != nil { + return fmt.Errorf("failed Bind(%d, %+v): %v", sockfd, bindAddr, err) + } + + // Enable IP_RECVORIGDSTADDR. + if err := syscall.SetsockoptInt(sockfd, level, option, 1); err != nil { + return fmt.Errorf("failed SetsockoptByte(%d, %d, %d, 1): %v", sockfd, level, option, err) + } + + addrCh := make(chan interface{}) + errCh := make(chan error) + go func() { + var addr interface{} + var err error + if ipv6 { + addr, err = recvOrigDstAddr6(sockfd) + } else { + addr, err = recvOrigDstAddr4(sockfd) + } + if err != nil { + errCh <- err + } else { + addrCh <- addr + } + }() + + // Wait to receive a packet. + var addr interface{} + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + case addr = <-addrCh: + } + + // Get a list of local IPs to verify that the packet now appears to have + // been sent to us. + var localAddrs []net.IP + if expectedDst != nil { + localAddrs = []net.IP{*expectedDst} + } else { + localAddrs, err = getInterfaceAddrs(ipv6) + if err != nil { + return fmt.Errorf("failed to get local interfaces: %w", err) + } + } + + // Verify that the address has the post-NAT port and address. + if ipv6 { + return addrMatches6(addr.(syscall.RawSockaddrInet6), localAddrs, redirectPort) + } + return addrMatches4(addr.(syscall.RawSockaddrInet4), localAddrs, redirectPort) +} + +func recvOrigDstAddr4(sockfd int) (syscall.RawSockaddrInet4, error) { + buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet4) + if err != nil { + return syscall.RawSockaddrInet4{}, err + } + var addr syscall.RawSockaddrInet4 + binary.Unmarshal(buf, usermem.ByteOrder, &addr) + return addr, nil +} + +func recvOrigDstAddr6(sockfd int) (syscall.RawSockaddrInet6, error) { + buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet6) + if err != nil { + return syscall.RawSockaddrInet6{}, err + } + var addr syscall.RawSockaddrInet6 + binary.Unmarshal(buf, usermem.ByteOrder, &addr) + return addr, nil +} + +func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) { + buf := make([]byte, 64) + oob := make([]byte, syscall.CmsgSpace(addrSize)) + for { + _, oobn, _, _, err := syscall.Recvmsg( + sockfd, + buf, // Message buffer. + oob, // Out-of-band buffer. + 0) // Flags. + if errors.Is(err, syscall.EINTR) { + continue + } + if err != nil { + return nil, fmt.Errorf("failed when calling Recvmsg: %w", err) + } + oob = oob[:oobn] + + // Parse out the control message. + msgs, err := syscall.ParseSocketControlMessage(oob) + if err != nil { + return nil, fmt.Errorf("failed to parse control message: %w", err) + } + return msgs[0].Data, nil + } +} + +func addrMatches4(got syscall.RawSockaddrInet4, wantAddrs []net.IP, port uint16) error { + for _, wantAddr := range wantAddrs { + want := syscall.RawSockaddrInet4{ + Family: syscall.AF_INET, + Port: htons(port), + } + copy(want.Addr[:], wantAddr.To4()) + if got == want { + return nil + } + } + return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs) +} + +func addrMatches6(got syscall.RawSockaddrInet6, wantAddrs []net.IP, port uint16) error { + for _, wantAddr := range wantAddrs { + want := syscall.RawSockaddrInet6{ + Family: syscall.AF_INET6, + Port: htons(port), + } + copy(want.Addr[:], wantAddr.To16()) + if got == want { + return nil + } + } + return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs) +} diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD index ccf1c735f..0be14ca3e 100644 --- a/test/packetimpact/dut/BUILD +++ b/test/packetimpact/dut/BUILD @@ -14,6 +14,7 @@ cc_binary( grpcpp, "//test/packetimpact/proto:posix_server_cc_grpc_proto", "//test/packetimpact/proto:posix_server_cc_proto", + "@com_google_absl//absl/strings:str_format", ], ) @@ -24,5 +25,6 @@ cc_binary( grpcpp, "//test/packetimpact/proto:posix_server_cc_grpc_proto", "//test/packetimpact/proto:posix_server_cc_proto", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index 4de8540f6..eba21df12 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -16,6 +16,7 @@ #include <getopt.h> #include <netdb.h> #include <netinet/in.h> +#include <poll.h> #include <stdio.h> #include <stdlib.h> #include <string.h> @@ -30,6 +31,7 @@ #include "include/grpcpp/security/server_credentials.h" #include "include/grpcpp/server_builder.h" #include "include/grpcpp/server_context.h" +#include "absl/strings/str_format.h" #include "test/packetimpact/proto/posix_server.grpc.pb.h" #include "test/packetimpact/proto/posix_server.pb.h" @@ -256,6 +258,44 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status::OK; } + ::grpc::Status Poll(::grpc::ServerContext *context, + const ::posix_server::PollRequest *request, + ::posix_server::PollResponse *response) override { + std::vector<struct pollfd> pfds; + pfds.reserve(request->pfds_size()); + for (const auto &pfd : request->pfds()) { + pfds.push_back({ + .fd = pfd.fd(), + .events = static_cast<short>(pfd.events()), + }); + } + int ret = ::poll(pfds.data(), pfds.size(), request->timeout_millis()); + + response->set_ret(ret); + if (ret < 0) { + response->set_errno_(errno); + } else { + // Only pollfds that have non-empty revents are returned, the client can't + // rely on indexes of the request array. + for (const auto &pfd : pfds) { + if (pfd.revents) { + auto *proto_pfd = response->add_pfds(); + proto_pfd->set_fd(pfd.fd); + proto_pfd->set_events(pfd.revents); + } + } + if (int ready = response->pfds_size(); ret != ready) { + return ::grpc::Status( + ::grpc::StatusCode::INTERNAL, + absl::StrFormat( + "poll's return value(%d) doesn't match the number of " + "file descriptors that are actually ready(%d)", + ret, ready)); + } + } + return ::grpc::Status::OK; + } + ::grpc::Status Send(::grpc::ServerContext *context, const ::posix_server::SendRequest *request, ::posix_server::SendResponse *response) override { diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto index f32ed54ef..b4c68764a 100644 --- a/test/packetimpact/proto/posix_server.proto +++ b/test/packetimpact/proto/posix_server.proto @@ -142,6 +142,25 @@ message ListenResponse { int32 errno_ = 2; // "errno" may fail to compile in c++. } +// The events field is overloaded: when used for request, it is copied into the +// events field of posix struct pollfd; when used for response, it is filled by +// the revents field from the posix struct pollfd. +message PollFd { + int32 fd = 1; + uint32 events = 2; +} + +message PollRequest { + repeated PollFd pfds = 1; + int32 timeout_millis = 2; +} + +message PollResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + repeated PollFd pfds = 3; +} + message SendRequest { int32 sockfd = 1; bytes buf = 2; @@ -226,6 +245,10 @@ service Posix { rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse); // Call listen() on the DUT. rpc Listen(ListenRequest) returns (ListenResponse); + // Call poll() on the DUT. Only pollfds that have non-empty revents are + // returned, the only way to tie the response back to the original request + // is using the fd number. + rpc Poll(PollRequest) returns (PollResponse); // Call send() on the DUT. rpc Send(SendRequest) returns (SendResponse); // Call sendto() on the DUT. diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index c6c95546a..a7c46781f 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -175,9 +175,6 @@ ALL_TESTS = [ name = "udp_discard_mcast_source_addr", ), PacketimpactTestInfo( - name = "udp_recv_mcast_bcast", - ), - PacketimpactTestInfo( name = "udp_any_addr_recv_unicast", ), PacketimpactTestInfo( @@ -277,6 +274,13 @@ ALL_TESTS = [ PacketimpactTestInfo( name = "tcp_rcv_buf_space", ), + PacketimpactTestInfo( + name = "tcp_rack", + expect_netstack_failure = True, + ), + PacketimpactTestInfo( + name = "tcp_info", + ), ] def validate_all_tests(): diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 576577310..1453ac232 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -1008,6 +1008,13 @@ func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { return sa } +// SrcPort returns the source port of this connection. +func (conn *UDPIPv4) SrcPort(t *testing.T) uint16 { + t.Helper() + + return *conn.udpState(t).out.SrcPort +} + // Send sends a packet with reasonable defaults, potentially overriding the UDP // layer and adding additionLayers. func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { @@ -1024,6 +1031,11 @@ func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ... (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) } +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *UDPIPv4) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + (*Connection)(conn).send(t, overrideLayers, additionalLayers...) +} + // Expect expects a frame with the UDP layer matching the provided UDP within // the timeout specified. If it doesn't arrive in time, an error is returned. func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { @@ -1053,6 +1065,14 @@ func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout return (*Connection)(conn).ExpectFrame(t, expected, timeout) } +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv4) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(conn).ExpectFrame(t, frame, timeout) +} + // Close frees associated resources held by the UDPIPv4 connection. func (conn *UDPIPv4) Close(t *testing.T) { t.Helper() @@ -1136,6 +1156,13 @@ func (conn *UDPIPv6) LocalAddr(t *testing.T, zoneID uint32) *unix.SockaddrInet6 return sa } +// SrcPort returns the source port of this connection. +func (conn *UDPIPv6) SrcPort(t *testing.T) uint16 { + t.Helper() + + return *conn.udpState(t).out.SrcPort +} + // Send sends a packet with reasonable defaults, potentially overriding the UDP // layer and adding additionLayers. func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { @@ -1152,6 +1179,11 @@ func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers . (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) } +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *UDPIPv6) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + (*Connection)(conn).send(t, overrideLayers, additionalLayers...) +} + // Expect expects a frame with the UDP layer matching the provided UDP within // the timeout specified. If it doesn't arrive in time, an error is returned. func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { @@ -1181,6 +1213,14 @@ func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout return (*Connection)(conn).ExpectFrame(t, expected, timeout) } +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv6) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(conn).ExpectFrame(t, frame, timeout) +} + // Close frees associated resources held by the UDPIPv6 connection. func (conn *UDPIPv6) Close(t *testing.T) { t.Helper() diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 66a0255b8..aedcf6013 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -486,6 +486,56 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backl return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } +// Poll calls poll on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over error handling is needed, use PollWithErrno. +// Only pollfds with non-empty revents are returned, the only way to tie the +// response back to the original request is using the fd number. +func (dut *DUT) Poll(t *testing.T, pfds []unix.PollFd, timeout time.Duration) []unix.PollFd { + t.Helper() + + ctx := context.Background() + var cancel context.CancelFunc + if timeout >= 0 { + ctx, cancel = context.WithTimeout(ctx, timeout+RPCTimeout) + defer cancel() + } + ret, result, err := dut.PollWithErrno(ctx, t, pfds, timeout) + if ret < 0 { + t.Fatalf("failed to poll: %s", err) + } + return result +} + +// PollWithErrno calls poll on the DUT. +func (dut *DUT) PollWithErrno(ctx context.Context, t *testing.T, pfds []unix.PollFd, timeout time.Duration) (int32, []unix.PollFd, error) { + t.Helper() + + req := pb.PollRequest{ + TimeoutMillis: int32(timeout.Milliseconds()), + } + for _, pfd := range pfds { + req.Pfds = append(req.Pfds, &pb.PollFd{ + Fd: pfd.Fd, + Events: uint32(pfd.Events), + }) + } + resp, err := dut.posixServer.Poll(ctx, &req) + if err != nil { + t.Fatalf("failed to call Poll: %s", err) + } + if ret, npfds := resp.GetRet(), len(resp.GetPfds()); ret >= 0 && int(ret) != npfds { + t.Fatalf("nonsensical poll response: ret(%d) != len(pfds)(%d)", ret, npfds) + } + var result []unix.PollFd + for _, protoPfd := range resp.GetPfds() { + result = append(result, unix.PollFd{ + Fd: protoPfd.GetFd(), + Revents: int16(protoPfd.GetEvents()), + }) + } + return resp.GetRet(), result, syscall.Errno(resp.GetErrno_()) +} + // Send calls send on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // SendWithErrno. @@ -544,7 +594,7 @@ func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, } resp, err := dut.posixServer.SendTo(ctx, &req) if err != nil { - t.Fatalf("faled to call SendTo: %s", err) + t.Fatalf("failed to call SendTo: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index b1b3c578b..baa3ae5e9 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -38,18 +38,6 @@ packetimpact_testbench( ) packetimpact_testbench( - name = "udp_recv_mcast_bcast", - srcs = ["udp_recv_mcast_bcast_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//test/packetimpact/testbench", - "@com_github_google_go_cmp//cmp:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -packetimpact_testbench( name = "udp_any_addr_recv_unicast", srcs = ["udp_any_addr_recv_unicast_test.go"], deps = [ @@ -340,6 +328,8 @@ packetimpact_testbench( name = "udp_send_recv_dgram", srcs = ["udp_send_recv_dgram_test.go"], deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", "//test/packetimpact/testbench", "@com_github_google_go_cmp//cmp:go_default_library", "@org_golang_x_sys//unix:go_default_library", @@ -376,6 +366,33 @@ packetimpact_testbench( ], ) +packetimpact_testbench( + name = "tcp_rack", + srcs = ["tcp_rack_test.go"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//pkg/usermem", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_testbench( + name = "tcp_info", + srcs = ["tcp_info_test.go"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/tcpip/header", + "//pkg/usermem", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + validate_all_tests() [packetimpact_go_test( diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go index d2203082d..ee050e2c6 100644 --- a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go +++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go @@ -45,8 +45,6 @@ func TestIPv4FragmentReassembly(t *testing.T) { ipPayloadLen int fragments []fragmentInfo expectReply bool - skip bool - skipReason string }{ { description: "basic reassembly", @@ -78,8 +76,6 @@ func TestIPv4FragmentReassembly(t *testing.T) { {offset: 2000, size: 1000, id: 7, more: 0}, }, expectReply: true, - skip: true, - skipReason: "gvisor.dev/issues/4971", }, { description: "fragment subset", @@ -91,8 +87,6 @@ func TestIPv4FragmentReassembly(t *testing.T) { {offset: 2000, size: 1000, id: 8, more: 0}, }, expectReply: true, - skip: true, - skipReason: "gvisor.dev/issues/4971", }, { description: "fragment overlap", @@ -104,16 +98,10 @@ func TestIPv4FragmentReassembly(t *testing.T) { {offset: 2000, size: 1000, id: 9, more: 0}, }, expectReply: false, - skip: true, - skipReason: "gvisor.dev/issues/4971", }, } for _, test := range tests { - if test.skip { - t.Skip("%s test skipped: %s", test.description, test.skipReason) - continue - } t.Run(test.description, func(t *testing.T) { dut := testbench.NewDUT(t) conn := dut.Net.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{}) diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go new file mode 100644 index 000000000..b66e8f609 --- /dev/null +++ b/test/packetimpact/tests/tcp_info_test.go @@ -0,0 +1,103 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_info_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +func TestTCPInfo(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + conn.Connect(t) + + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) + + // Send and receive sample data. + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + dut.Send(t, acceptFD, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + info := linux.TCPInfo{} + infoBytes := dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + binary.Unmarshal(infoBytes, usermem.ByteOrder, &info) + + rtt := time.Duration(info.RTT) * time.Microsecond + rttvar := time.Duration(info.RTTVar) * time.Microsecond + rto := time.Duration(info.RTO) * time.Microsecond + if rtt == 0 || rttvar == 0 || rto == 0 { + t.Errorf("expected rtt(%v), rttvar(%v) and rto(%v) to be greater than zero", rtt, rttvar, rto) + } + if info.ReordSeen != 0 { + t.Errorf("expected the connection to not have any reordering, got: %v want: 0", info.ReordSeen) + } + if info.SndCwnd == 0 { + t.Errorf("expected send congestion window to be greater than zero") + } + if info.CaState != linux.TCP_CA_Open { + t.Errorf("expected the connection to be in open state, got: %v want: %v", info.CaState, linux.TCP_CA_Open) + } + + if t.Failed() { + t.FailNow() + } + + // Check the congestion control state and send congestion window after + // retransmission timeout. + seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) + dut.Send(t, acceptFD, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + // Expect retransmission of the packet within 1.5*RTO. + timeout := time.Duration(float64(info.RTO)*1.5) * time.Microsecond + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, timeout); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + info = linux.TCPInfo{} + infoBytes = dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + binary.Unmarshal(infoBytes, usermem.ByteOrder, &info) + if info.CaState != linux.TCP_CA_Loss { + t.Errorf("expected the connection to be in loss recovery, got: %v want: %v", info.CaState, linux.TCP_CA_Loss) + } + if info.SndCwnd != 1 { + t.Errorf("expected send congestion window to be 1, got: %v %v", info.SndCwnd) + } +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go index f0af5352d..c874a8912 100644 --- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -34,6 +34,21 @@ func TestTcpNoAcceptCloseReset(t *testing.T) { conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) conn.Connect(t) defer conn.Close(t) + // We need to wait for POLLIN event on listenFd to know the connection is + // established. Otherwise there could be a race when we issue the Close + // command prior to the DUT receiving the last ack of the handshake and + // it will only respond RST instead of RST+ACK. + timeout := time.Second + pfds := dut.Poll(t, []unix.PollFd{{Fd: listenFd, Events: unix.POLLIN}}, timeout) + if n := len(pfds); n != 1 { + t.Fatalf("poll returned %d ready file descriptors, expected 1", n) + } + if readyFd := pfds[0].Fd; readyFd != listenFd { + t.Fatalf("poll returned an fd %d that was not requested (%d)", readyFd, listenFd) + } + if got, want := pfds[0].Revents, int16(unix.POLLIN); got&want == 0 { + t.Fatalf("poll returned no events in our interest, got: %#b, want: %#b", got, want) + } dut.Close(t, listenFd) if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { t.Fatalf("expected a RST-ACK packet but got none: %s", err) diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go new file mode 100644 index 000000000..0a2381c97 --- /dev/null +++ b/test/packetimpact/tests/tcp_rack_test.go @@ -0,0 +1,221 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_rack_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +const ( + // payloadSize is the size used to send packets. + payloadSize = header.TCPDefaultMSS + + // simulatedRTT is the time delay between packets sent and acked to + // increase the RTT. + simulatedRTT = 30 * time.Millisecond + + // numPktsForRTT is the number of packets sent and acked to establish + // RTT. + numPktsForRTT = 10 +) + +func createSACKConnection(t *testing.T) (testbench.DUT, testbench.TCPIPv4, int32, int32) { + dut := testbench.NewDUT(t) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + + // Enable SACK. + opts := make([]byte, 40) + optsOff := 0 + optsOff += header.EncodeNOP(opts[optsOff:]) + optsOff += header.EncodeNOP(opts[optsOff:]) + optsOff += header.EncodeSACKPermittedOption(opts[optsOff:]) + + conn.ConnectWithOptions(t, opts[:optsOff]) + acceptFd, _ := dut.Accept(t, listenFd) + return dut, conn, acceptFd, listenFd +} + +func closeSACKConnection(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4, acceptFd, listenFd int32) { + dut.Close(t, acceptFd) + dut.Close(t, listenFd) + conn.Close(t) +} + +func getRTTAndRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rtt, rto time.Duration) { + info := linux.TCPInfo{} + ret := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + binary.Unmarshal(ret, usermem.ByteOrder, &info) + return time.Duration(info.RTT) * time.Microsecond, time.Duration(info.RTO) * time.Microsecond +} + +func sendAndReceive(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4, numPkts int, acceptFd int32, sendACK bool) time.Time { + seqNum1 := *conn.RemoteSeqNum(t) + payload := make([]byte, payloadSize) + var lastSent time.Time + for i, sn := 0, seqNum1; i < numPkts; i++ { + lastSent = time.Now() + dut.Send(t, acceptFd, payload, 0) + gotOne, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(sn))}, time.Second) + if err != nil { + t.Fatalf("Expect #%d: %s", i+1, err) + continue + } + if gotOne == nil { + t.Fatalf("#%d: expected a packet within a second but got none", i+1) + } + sn.UpdateForward(seqnum.Size(payloadSize)) + + if sendACK { + time.Sleep(simulatedRTT) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(sn))}) + } + } + return lastSent +} + +// TestRACKTLPAllPacketsLost tests TLP when an entire flight of data is lost. +func TestRACKTLPAllPacketsLost(t *testing.T) { + dut, conn, acceptFd, listenFd := createSACKConnection(t) + seqNum1 := *conn.RemoteSeqNum(t) + + // Send ACK for data packets to establish RTT. + sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */) + seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize)) + + // We are not sending ACK for these packets. + const numPkts = 5 + lastSent := sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + + // Probe Timeout (PTO) should be two times RTT. Check that the last + // packet is retransmitted after probe timeout. + rtt, _ := getRTTAndRTO(t, dut, acceptFd) + pto := rtt * 2 + // We expect the 5th packet (the last unacknowledged packet) to be + // retransmitted. + tlpProbe := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize)) + if _, err := conn.Expect(t, testbench.TCP{SeqNum: tlpProbe}, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s %v %v", err, rtt, pto) + } + diff := time.Now().Sub(lastSent) + if diff < pto { + t.Fatalf("expected payload was received before the probe timeout, got: %v, want: %v", diff, pto) + } + closeSACKConnection(t, dut, conn, acceptFd, listenFd) +} + +// TestRACKTLPLost tests TLP when there are tail losses. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.4 +func TestRACKTLPLost(t *testing.T) { + dut, conn, acceptFd, listenFd := createSACKConnection(t) + seqNum1 := *conn.RemoteSeqNum(t) + + // Send ACK for data packets to establish RTT. + sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */) + seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize)) + + // We are not sending ACK for these packets. + const numPkts = 10 + lastSent := sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + + // Cumulative ACK for #[1-5] packets. + ackNum := seqNum1.Add(seqnum.Size(6 * payloadSize)) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(ackNum))}) + + // Probe Timeout (PTO) should be two times RTT. Check that the last + // packet is retransmitted after probe timeout. + rtt, _ := getRTTAndRTO(t, dut, acceptFd) + pto := rtt * 2 + // We expect the 10th packet (the last unacknowledged packet) to be + // retransmitted. + tlpProbe := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize)) + if _, err := conn.Expect(t, testbench.TCP{SeqNum: tlpProbe}, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + diff := time.Now().Sub(lastSent) + if diff < pto { + t.Fatalf("expected payload was received before the probe timeout, got: %v, want: %v", diff, pto) + } + closeSACKConnection(t, dut, conn, acceptFd, listenFd) +} + +// TestRACKTLPWithSACK tests TLP by acknowledging out of order packets. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-8.1 +func TestRACKTLPWithSACK(t *testing.T) { + dut, conn, acceptFd, listenFd := createSACKConnection(t) + seqNum1 := *conn.RemoteSeqNum(t) + + // Send ACK for data packets to establish RTT. + sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */) + seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize)) + + // We are not sending ACK for these packets. + const numPkts = 3 + lastSent := sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + + // SACK for #2 packet. + sackBlock := make([]byte, 40) + start := seqNum1.Add(seqnum.Size(payloadSize)) + end := start.Add(seqnum.Size(payloadSize)) + sbOff := 0 + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + start, end, + }}, sackBlock[sbOff:]) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + + // RACK marks #1 packet as lost and retransmits it. + if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // ACK for #1 packet. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(end))}) + + // Probe Timeout (PTO) should be two times RTT. TLP will trigger for #3 + // packet. RACK adds an additional timeout of 200ms if the number of + // outstanding packets is equal to 1. + rtt, rto := getRTTAndRTO(t, dut, acceptFd) + pto := rtt*2 + (200 * time.Millisecond) + if rto < pto { + pto = rto + } + // We expect the 3rd packet (the last unacknowledged packet) to be + // retransmitted. + tlpProbe := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize)) + if _, err := conn.Expect(t, testbench.TCP{SeqNum: tlpProbe}, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + diff := time.Now().Sub(lastSent) + if diff < pto { + t.Fatalf("expected payload was received before the probe timeout, got: %v, want: %v", diff, pto) + } + closeSACKConnection(t, dut, conn, acceptFd, listenFd) +} diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go index 1ab9ee1b2..b15b8fc25 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -66,33 +66,39 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) - startProbeDuration := time.Second - current := startProbeDuration - first := time.Now() // Ask the dut to send out data. dut.Send(t, acceptFd, sampleData, 0) + + var prev time.Duration // Expect the dut to keep the connection alive as long as the remote is // acknowledging the zero-window probes. - for i := 0; i < 5; i++ { + for i := 1; i <= 5; i++ { start := time.Now() // Expect zero-window probe with a timeout which is a function of the typical // first retransmission time. The retransmission times is supposed to // exponentially increase. - if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Duration(i)*time.Second); err != nil { t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i) } - if i == 0 { - startProbeDuration = time.Now().Sub(first) - current = 2 * startProbeDuration + if i == 1 { + // Skip the first probe as computing transmit time for that is + // non-deterministic because of the arbitrary time taken for + // the dut to receive a send command and issue a send. continue } - // Check if the probes came at exponentially increasing intervals. - if got, want := time.Since(start), current-startProbeDuration; got < want { + + // Check if the time taken to receive the probe from the dut is + // increasing exponentially. To avoid flakes, use a correction + // factor for the expected duration which accounts for any + // scheduling non-determinism. + const timeCorrection = 200 * time.Millisecond + got := time.Since(start) + if want := (2 * prev) - timeCorrection; prev != 0 && got < want { t.Errorf("got zero probe %d after %s, want >= %s", i, got, want) } + prev = got // Acknowledge the zero-window probes from the dut. conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) - current *= 2 } // Advertize non-zero window. conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go deleted file mode 100644 index b29c07825..000000000 --- a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_recv_mcast_bcast_test - -import ( - "context" - "flag" - "fmt" - "net" - "syscall" - "testing" - - "github.com/google/go-cmp/cmp" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/test/packetimpact/testbench" -) - -func init() { - testbench.Initialize(flag.CommandLine) -} - -func TestUDPRecvMcastBcast(t *testing.T) { - dut := testbench.NewDUT(t) - subnetBcastAddr := broadcastAddr(dut.Net.RemoteIPv4, net.CIDRMask(dut.Net.IPv4PrefixLength, 32)) - for _, v := range []struct { - bound, to net.IP - }{ - {bound: net.IPv4zero, to: subnetBcastAddr}, - {bound: net.IPv4zero, to: net.IPv4bcast}, - {bound: net.IPv4zero, to: net.IPv4allsys}, - - {bound: subnetBcastAddr, to: subnetBcastAddr}, - - // FIXME(gvisor.dev/issue/4896): Previously by the time subnetBcastAddr is - // created, IPv4PrefixLength is still 0 because genPseudoFlags is not called - // yet, it was only called in NewDUT, so the test didn't do what the author - // original intended to and becomes failing because we process all flags at - // the very beginning. - // - // {bound: subnetBcastAddr, to: net.IPv4bcast}, - - {bound: net.IPv4bcast, to: net.IPv4bcast}, - {bound: net.IPv4allsys, to: net.IPv4allsys}, - } { - t.Run(fmt.Sprintf("bound=%s,to=%s", v.bound, v.to), func(t *testing.T) { - boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, v.bound) - defer dut.Close(t, boundFD) - conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close(t) - - payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) - conn.SendIP( - t, - testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(v.to.To4()))}, - testbench.UDP{}, - &testbench.Payload{Bytes: payload}, - ) - got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) - } - }) - } -} - -func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) { - dut := testbench.NewDUT(t) - boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, dut.Net.RemoteIPv4) - dut.SetSockOptTimeval(t, boundFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: 1, Usec: 0}) - defer dut.Close(t, boundFD) - conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close(t) - - for _, to := range []net.IP{ - broadcastAddr(dut.Net.RemoteIPv4, net.CIDRMask(dut.Net.IPv4PrefixLength, 32)), - net.IPv4(255, 255, 255, 255), - net.IPv4(224, 0, 0, 1), - } { - t.Run(fmt.Sprint("to=%s", to), func(t *testing.T) { - payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) - conn.SendIP( - t, - testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(to.To4()))}, - testbench.UDP{}, - &testbench.Payload{Bytes: payload}, - ) - ret, payload, errno := dut.RecvWithErrno(context.Background(), t, boundFD, 100, 0) - if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { - t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) - } - }) - } -} - -func broadcastAddr(ip net.IP, mask net.IPMask) net.IP { - result := make(net.IP, net.IPv4len) - ip4 := ip.To4() - for i := range ip4 { - result[i] = ip4[i] | ^mask[i] - } - return result -} diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go index 7ee2c8014..6e45cb143 100644 --- a/test/packetimpact/tests/udp_send_recv_dgram_test.go +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -15,13 +15,18 @@ package udp_send_recv_dgram_test import ( + "context" "flag" + "fmt" "net" + "syscall" "testing" "time" "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -30,74 +35,295 @@ func init() { } type udpConn interface { - Send(*testing.T, testbench.UDP, ...testbench.Layer) - ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error) - Drain(*testing.T) + SrcPort(*testing.T) uint16 + SendFrame(*testing.T, testbench.Layers, ...testbench.Layer) + ExpectFrame(*testing.T, testbench.Layers, time.Duration) (testbench.Layers, error) Close(*testing.T) } +type testCase struct { + bindTo, sendTo net.IP + sendToBroadcast, bindToDevice, expectData bool +} + func TestUDP(t *testing.T) { dut := testbench.NewDUT(t) + subnetBcast := func() net.IP { + subnet := (&tcpip.AddressWithPrefix{ + Address: tcpip.Address(dut.Net.RemoteIPv4.To4()), + PrefixLen: dut.Net.IPv4PrefixLength, + }).Subnet() + return net.IP(subnet.Broadcast()) + }() - for _, isIPv4 := range []bool{true, false} { - ipVersionName := "IPv6" - if isIPv4 { - ipVersionName = "IPv4" - } - t.Run(ipVersionName, func(t *testing.T) { - var addr net.IP - if isIPv4 { - addr = dut.Net.RemoteIPv4 - } else { - addr = dut.Net.RemoteIPv6 + t.Run("Send", func(t *testing.T) { + var testCases []testCase + // Test every valid combination of bound/unbound, broadcast/multicast/unicast + // bound/destination address, and bound/not-bound to device. + for _, bindTo := range []net.IP{ + nil, // Do not bind. + net.IPv4zero, + net.IPv4bcast, + net.IPv4allsys, + subnetBcast, + dut.Net.RemoteIPv4, + dut.Net.RemoteIPv6, + } { + for _, sendTo := range []net.IP{ + net.IPv4bcast, + net.IPv4allsys, + subnetBcast, + dut.Net.LocalIPv4, + dut.Net.LocalIPv6, + } { + // Cannot send to an IPv4 address from a socket bound to IPv6 (except for IPv4-mapped IPv6), + // and viceversa. + if bindTo != nil && ((bindTo.To4() == nil) != (sendTo.To4() == nil)) { + continue + } + for _, bindToDevice := range []bool{true, false} { + expectData := true + switch { + case bindTo.Equal(dut.Net.RemoteIPv4): + // If we're explicitly bound to an interface's unicast address, + // packets are always sent on that interface. + case bindToDevice: + // If we're explicitly bound to an interface, packets are always + // sent on that interface. + case !sendTo.Equal(net.IPv4bcast) && !sendTo.IsMulticast(): + // If we're not sending to limited broadcast or multicast, the route table + // will be consulted and packets will be sent on the correct interface. + default: + expectData = false + } + testCases = append( + testCases, + testCase{ + bindTo: bindTo, + sendTo: sendTo, + sendToBroadcast: sendTo.Equal(subnetBcast) || sendTo.Equal(net.IPv4bcast), + bindToDevice: bindToDevice, + expectData: expectData, + }, + ) + } } - boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, addr) - defer dut.Close(t, boundFD) - - var conn udpConn - var localAddr unix.Sockaddr - if isIPv4 { - v4Conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - localAddr = v4Conn.LocalAddr(t) - conn = &v4Conn - } else { - v6Conn := dut.Net.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - localAddr = v6Conn.LocalAddr(t, dut.Net.RemoteDevID) - conn = &v6Conn - } - defer conn.Close(t) - - testCases := []struct { - name string - payload []byte - }{ - {"emptypayload", nil}, - {"small payload", []byte("hello world")}, - {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, - // Even though UDP allows larger dgrams we don't test it here as - // they need to be fragmented and written out as individual - // frames. + } + for _, tc := range testCases { + boundTestCaseName := "unbound" + if tc.bindTo != nil { + boundTestCaseName = fmt.Sprintf("bindTo=%s", tc.bindTo) } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Run("Send", func(t *testing.T) { - conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) - got, want := dut.Recv(t, boundFD, int32(len(tc.payload)+1), 0), tc.payload - if diff := cmp.Diff(want, got); diff != "" { - t.Fatalf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + t.Run(fmt.Sprintf("%s/sendTo=%s/bindToDevice=%t/expectData=%t", boundTestCaseName, tc.sendTo, tc.bindToDevice, tc.expectData), func(t *testing.T) { + runTestCase( + t, + dut, + tc, + func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers) { + var destSockaddr unix.Sockaddr + if sendTo4 := tc.sendTo.To4(); sendTo4 != nil { + addr := unix.SockaddrInet4{ + Port: int(conn.SrcPort(t)), + } + copy(addr.Addr[:], sendTo4) + destSockaddr = &addr + } else { + addr := unix.SockaddrInet6{ + Port: int(conn.SrcPort(t)), + ZoneId: dut.Net.RemoteDevID, + } + copy(addr.Addr[:], tc.sendTo.To16()) + destSockaddr = &addr } - }) - t.Run("Recv", func(t *testing.T) { - conn.Drain(t) - if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want { - t.Fatalf("short write got: %d, want: %d", got, want) + if got, want := dut.SendTo(t, socketFD, payload, 0, destSockaddr), len(payload); int(got) != want { + t.Fatalf("got dut.SendTo = %d, want %d", got, want) } - if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil { + layers = append(layers, &testbench.Payload{ + Bytes: payload, + }) + _, err := conn.ExpectFrame(t, layers, time.Second) + + if !tc.expectData && err == nil { + t.Fatal("received unexpected packet, socket is not bound to device") + } + if err != nil && tc.expectData { t.Fatal(err) } - }) - }) + }, + ) + }) + } + }) + t.Run("Recv", func(t *testing.T) { + // Test every valid combination of broadcast/multicast/unicast + // bound/destination address, and bound/not-bound to device. + var testCases []testCase + for _, addr := range []net.IP{ + net.IPv4bcast, + net.IPv4allsys, + dut.Net.RemoteIPv4, + dut.Net.RemoteIPv6, + } { + for _, bindToDevice := range []bool{true, false} { + testCases = append( + testCases, + testCase{ + bindTo: addr, + sendTo: addr, + sendToBroadcast: addr.Equal(subnetBcast) || addr.Equal(net.IPv4bcast), + bindToDevice: bindToDevice, + expectData: true, + }, + ) } - }) + } + for _, bindTo := range []net.IP{ + net.IPv4zero, + subnetBcast, + dut.Net.RemoteIPv4, + } { + for _, sendTo := range []net.IP{ + subnetBcast, + net.IPv4bcast, + net.IPv4allsys, + } { + // TODO(gvisor.dev/issue/4896): Add bindTo=subnetBcast/sendTo=IPv4bcast + // and bindTo=subnetBcast/sendTo=IPv4allsys test cases. + if bindTo.Equal(subnetBcast) && (sendTo.Equal(net.IPv4bcast) || sendTo.IsMulticast()) { + continue + } + // Expect that a socket bound to a unicast address does not receive + // packets sent to an address other than the bound unicast address. + // + // Note: we cannot use net.IP.IsGlobalUnicast to test this condition + // because IsGlobalUnicast does not check whether the address is the + // subnet broadcast, and returns true in that case. + expectData := !bindTo.Equal(dut.Net.RemoteIPv4) || sendTo.Equal(dut.Net.RemoteIPv4) + for _, bindToDevice := range []bool{true, false} { + testCases = append( + testCases, + testCase{ + bindTo: bindTo, + sendTo: sendTo, + sendToBroadcast: sendTo.Equal(subnetBcast) || sendTo.Equal(net.IPv4bcast), + bindToDevice: bindToDevice, + expectData: expectData, + }, + ) + } + } + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("bindTo=%s/sendTo=%s/bindToDevice=%t/expectData=%t", tc.bindTo, tc.sendTo, tc.bindToDevice, tc.expectData), func(t *testing.T) { + runTestCase( + t, + dut, + tc, + func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers) { + conn.SendFrame(t, layers, &testbench.Payload{Bytes: payload}) + + if tc.expectData { + got, want := dut.Recv(t, socketFD, int32(len(payload)+1), 0), payload + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } + } else { + // Expected receive error, set a short receive timeout. + dut.SetSockOptTimeval( + t, + socketFD, + unix.SOL_SOCKET, + unix.SO_RCVTIMEO, + &unix.Timeval{ + Sec: 1, + Usec: 0, + }, + ) + ret, recvPayload, errno := dut.RecvWithErrno(context.Background(), t, socketFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, recvPayload, errno) + } + } + }, + ) + }) + } + }) +} + +func runTestCase( + t *testing.T, + dut testbench.DUT, + tc testCase, + runTc func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers), +) { + var ( + socketFD int32 + outgoingUDP, incomingUDP testbench.UDP + ) + if tc.bindTo != nil { + var remotePort uint16 + socketFD, remotePort = dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, tc.bindTo) + outgoingUDP.DstPort = &remotePort + incomingUDP.SrcPort = &remotePort + } else { + // An unbound socket will auto-bind to INNADDR_ANY and a random + // port on sendto. + socketFD = dut.Socket(t, unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + } + defer dut.Close(t, socketFD) + if tc.bindToDevice { + dut.SetSockOpt(t, socketFD, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, []byte(dut.Net.RemoteDevName)) + } + + var ethernetLayer testbench.Ether + if tc.sendToBroadcast { + dut.SetSockOptInt(t, socketFD, unix.SOL_SOCKET, unix.SO_BROADCAST, 1) + + // When sending to broadcast (subnet or limited), the expected ethernet + // address is also broadcast. + ethernetBroadcastAddress := header.EthernetBroadcastAddress + ethernetLayer.DstAddr = ðernetBroadcastAddress + } else if tc.sendTo.IsMulticast() { + ethernetMulticastAddress := header.EthernetAddressFromMulticastIPv4Address(tcpip.Address(tc.sendTo.To4())) + ethernetLayer.DstAddr = ðernetMulticastAddress + } + expectedLayers := testbench.Layers{ðernetLayer} + + var conn udpConn + if sendTo4 := tc.sendTo.To4(); sendTo4 != nil { + v4Conn := dut.Net.NewUDPIPv4(t, outgoingUDP, incomingUDP) + conn = &v4Conn + expectedLayers = append( + expectedLayers, + &testbench.IPv4{ + DstAddr: testbench.Address(tcpip.Address(sendTo4)), + }, + ) + } else { + v6Conn := dut.Net.NewUDPIPv6(t, outgoingUDP, incomingUDP) + conn = &v6Conn + expectedLayers = append( + expectedLayers, + &testbench.IPv6{ + DstAddr: testbench.Address(tcpip.Address(tc.sendTo)), + }, + ) + } + defer conn.Close(t) + + expectedLayers = append(expectedLayers, &incomingUDP) + for _, v := range []struct { + name string + payload []byte + }{ + {"emptypayload", nil}, + {"small payload", []byte("hello world")}, + {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, + // Even though UDP allows larger dgrams we don't test it here as + // they need to be fragmented and written out as individual + // frames. + } { + runTc(t, dut, conn, socketFD, tc, v.payload, expectedLayers) } } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 0da35f7be..6ee2b73c1 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -569,7 +569,7 @@ syscall_test( # syscall_test(vfs2="True",test = "//test/syscalls/linux:sigaltstack_test") syscall_test( - test = "//test/syscalls/linux:sigiret_test", + test = "//test/syscalls/linux:sigreturn_test", ) syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index d184712e3..2b4b6f348 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -582,6 +582,7 @@ cc_binary( "//test/util:eventfd_util", "//test/util:file_descriptor", gtest, + "//test/util:fs_util", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", @@ -797,6 +798,7 @@ cc_binary( linkstatic = 1, deps = [ ":socket_test_util", + "//test/util:capability_util", "//test/util:cleanup", "//test/util:eventfd_util", "//test/util:file_descriptor", @@ -807,6 +809,7 @@ cc_binary( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", gtest, + "//test/util:memory_util", "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:save_util", @@ -978,6 +981,7 @@ cc_binary( "//test/util:epoll_util", "//test/util:file_descriptor", "//test/util:fs_util", + "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", @@ -2191,11 +2195,11 @@ cc_binary( ) cc_binary( - name = "sigiret_test", + name = "sigreturn_test", testonly = 1, srcs = select_arch( - amd64 = ["sigiret.cc"], - arm64 = [], + amd64 = ["sigreturn_amd64.cc"], + arm64 = ["sigreturn_arm64.cc"], ), linkstatic = 1, deps = [ diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc index a06b5cfd6..8233df0f8 100644 --- a/test/syscalls/linux/chmod.cc +++ b/test/syscalls/linux/chmod.cc @@ -98,6 +98,42 @@ TEST(ChmodTest, FchmodatBadF) { ASSERT_THAT(fchmodat(-1, "foo", 0444, 0), SyscallFailsWithErrno(EBADF)); } +TEST(ChmodTest, FchmodFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + ASSERT_THAT(fchmod(fd.get(), 0444), SyscallFailsWithErrno(EBADF)); +} + +TEST(ChmodTest, FchmodDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH)); + + ASSERT_THAT(fchmod(fd.get(), 0444), SyscallFailsWithErrno(EBADF)); +} + +TEST(ChmodTest, FchmodatWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + // Drop capabilities that allow us to override file permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + + const auto parent_fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(GetAbsoluteTestTmpdir().c_str(), O_PATH | O_DIRECTORY)); + + ASSERT_THAT( + fchmodat(parent_fd.get(), std::string(Basename(temp_file.path())).c_str(), + 0444, 0), + SyscallSucceeds()); + + EXPECT_THAT(open(temp_file.path().c_str(), O_RDWR), + SyscallFailsWithErrno(EACCES)); +} + TEST(ChmodTest, FchmodatNotDir) { ASSERT_THAT(fchmodat(-1, "", 0444, 0), SyscallFailsWithErrno(ENOENT)); } diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc index 5530ad18f..ff0d39343 100644 --- a/test/syscalls/linux/chown.cc +++ b/test/syscalls/linux/chown.cc @@ -48,6 +48,36 @@ TEST(ChownTest, FchownatBadF) { ASSERT_THAT(fchownat(-1, "fff", 0, 0, 0), SyscallFailsWithErrno(EBADF)); } +TEST(ChownTest, FchownFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + ASSERT_THAT(fchown(fd.get(), geteuid(), getegid()), + SyscallFailsWithErrno(EBADF)); +} + +TEST(ChownTest, FchownDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH)); + + ASSERT_THAT(fchown(fd.get(), geteuid(), getegid()), + SyscallFailsWithErrno(EBADF)); +} + +TEST(ChownTest, FchownatWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + const auto dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH)); + ASSERT_THAT( + fchownat(dirfd.get(), file.path().c_str(), geteuid(), getegid(), 0), + SyscallSucceeds()); +} + TEST(ChownTest, FchownatEmptyPath) { const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const auto fd = @@ -209,6 +239,14 @@ INSTANTIATE_TEST_SUITE_P( owner, group, 0); MaybeSave(); return errorFromReturn("fchownat-dirfd", rc); + }, + [](const std::string& path, uid_t owner, gid_t group) -> PosixError { + ASSIGN_OR_RETURN_ERRNO(auto dirfd, Open(std::string(Dirname(path)), + O_DIRECTORY | O_PATH)); + int rc = fchownat(dirfd.get(), std::string(Basename(path)).c_str(), + owner, group, 0); + MaybeSave(); + return errorFromReturn("fchownat-opathdirfd", rc); })); } // namespace diff --git a/test/syscalls/linux/dup.cc b/test/syscalls/linux/dup.cc index 4f773bc75..ba4e13fb9 100644 --- a/test/syscalls/linux/dup.cc +++ b/test/syscalls/linux/dup.cc @@ -18,6 +18,7 @@ #include "gtest/gtest.h" #include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -44,14 +45,6 @@ PosixErrorOr<FileDescriptor> Dup3(const FileDescriptor& fd, int target_fd, return FileDescriptor(new_fd); } -void CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2) { - struct stat stat_result1, stat_result2; - ASSERT_THAT(fstat(fd1.get(), &stat_result1), SyscallSucceeds()); - ASSERT_THAT(fstat(fd2.get(), &stat_result2), SyscallSucceeds()); - EXPECT_EQ(stat_result1.st_dev, stat_result2.st_dev); - EXPECT_EQ(stat_result1.st_ino, stat_result2.st_ino); -} - TEST(DupTest, Dup) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); @@ -59,7 +52,7 @@ TEST(DupTest, Dup) { // Dup the descriptor and make sure it's the same file. FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); } TEST(DupTest, DupClearsCloExec) { @@ -70,10 +63,24 @@ TEST(DupTest, DupClearsCloExec) { // Duplicate the descriptor. Ensure that it doesn't have FD_CLOEXEC set. FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); } +TEST(DupTest, DupWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH)); + int flags; + ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); + + // Dup the descriptor and make sure it's the same file. + FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); + ASSERT_NE(fd.get(), nfd.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags)); +} + TEST(DupTest, Dup2) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); @@ -82,13 +89,13 @@ TEST(DupTest, Dup2) { FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); // Dup over the file above. int target_fd = nfd.release(); FileDescriptor nfd2 = ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, target_fd)); EXPECT_EQ(target_fd, nfd2.get()); - CheckSameFile(fd, nfd2); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd2)); } TEST(DupTest, Dup2SameFD) { @@ -99,6 +106,28 @@ TEST(DupTest, Dup2SameFD) { ASSERT_THAT(dup2(fd.get(), fd.get()), SyscallSucceedsWithValue(fd.get())); } +TEST(DupTest, Dup2WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH)); + int flags; + ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); + + // Regular dup once. + FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); + + ASSERT_NE(fd.get(), nfd.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags)); + + // Dup over the file above. + int target_fd = nfd.release(); + FileDescriptor nfd2 = ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, target_fd)); + EXPECT_EQ(target_fd, nfd2.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd2)); + EXPECT_THAT(fcntl(nfd2.get(), F_GETFL), SyscallSucceedsWithValue(flags)); +} + TEST(DupTest, Dup3) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); @@ -106,16 +135,16 @@ TEST(DupTest, Dup3) { // Regular dup once. FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); // Dup over the file above, check that it has no CLOEXEC. nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), 0)); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); // Dup over the file again, check that it does not CLOEXEC. nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), O_CLOEXEC)); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); } @@ -127,6 +156,32 @@ TEST(DupTest, Dup3FailsSameFD) { ASSERT_THAT(dup3(fd.get(), fd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +TEST(DupTest, Dup3WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH)); + EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0)); + int flags; + ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); + + // Regular dup once. + FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); + ASSERT_NE(fd.get(), nfd.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + + // Dup over the file above, check that it has no CLOEXEC. + nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), 0)); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags)); + + // Dup over the file again, check that it does not CLOEXEC. + nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), O_CLOEXEC)); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/fadvise64.cc b/test/syscalls/linux/fadvise64.cc index 2af7aa6d9..ac24c4066 100644 --- a/test/syscalls/linux/fadvise64.cc +++ b/test/syscalls/linux/fadvise64.cc @@ -45,6 +45,17 @@ TEST(FAdvise64Test, Basic) { SyscallSucceeds()); } +TEST(FAdvise64Test, FAdvise64WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NORMAL), + SyscallFailsWithErrno(EBADF)); + ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NORMAL), + SyscallFailsWithErrno(EBADF)); +} + TEST(FAdvise64Test, InvalidArgs) { auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index edd23e063..5c839447e 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -108,6 +108,13 @@ TEST_F(AllocateTest, FallocateReadonly) { EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(EBADF)); } +TEST_F(AllocateTest, FallocateWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(EBADF)); +} + TEST_F(AllocateTest, FallocatePipe) { int pipes[2]; EXPECT_THAT(pipe(pipes), SyscallSucceeds()); diff --git a/test/syscalls/linux/fchdir.cc b/test/syscalls/linux/fchdir.cc index 08bcae1e8..c6675802d 100644 --- a/test/syscalls/linux/fchdir.cc +++ b/test/syscalls/linux/fchdir.cc @@ -71,6 +71,18 @@ TEST(FchdirTest, NotDir) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST(FchdirTest, FchdirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(temp_dir.path(), O_PATH)); + ASSERT_THAT(open(temp_dir.path().c_str(), O_DIRECTORY | O_PATH), + SyscallSucceeds()); + + EXPECT_THAT(fchdir(fd.get()), SyscallSucceeds()); + // Change CWD to a permanent location as temp dirs will be cleaned up. + EXPECT_THAT(chdir("/"), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 4b581045b..4fa6751ff 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -15,6 +15,7 @@ #include <fcntl.h> #include <signal.h> #include <sys/epoll.h> +#include <sys/mman.h> #include <sys/types.h> #include <syscall.h> #include <unistd.h> @@ -35,10 +36,12 @@ #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" #include "test/util/cleanup.h" #include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/memory_util.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" #include "test/util/save_util.h" @@ -204,6 +207,41 @@ PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write, return std::move(cleanup); } +TEST(FcntlTest, FcntlDupWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH)); + + int new_fd; + // Dup the descriptor and make sure it's the same file. + EXPECT_THAT(new_fd = fcntl(fd.get(), F_DUPFD, 0), SyscallSucceeds()); + + FileDescriptor nfd = FileDescriptor(new_fd); + ASSERT_NE(fd.get(), nfd.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(O_PATH)); +} + +TEST(FcntlTest, SetFileStatusFlagWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + + EXPECT_THAT(fcntl(fd.get(), F_SETFL, 0), SyscallFailsWithErrno(EBADF)); +} + +TEST(FcntlTest, BadFcntlsWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + + EXPECT_THAT(fcntl(fd.get(), F_SETOWN, 0), SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(fcntl(fd.get(), F_GETOWN, 0), SyscallFailsWithErrno(EBADF)); + + EXPECT_THAT(fcntl(fd.get(), F_SETOWN_EX, 0), SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(fcntl(fd.get(), F_GETOWN_EX, 0), SyscallFailsWithErrno(EBADF)); +} + TEST(FcntlTest, SetCloExecBadFD) { // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set. FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); @@ -223,6 +261,32 @@ TEST(FcntlTest, SetCloExec) { ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); } +TEST(FcntlTest, SetCloExecWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + // Open a file descriptor with FD_CLOEXEC descriptor flag not set. + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0)); + + // Set the FD_CLOEXEC flag. + ASSERT_THAT(fcntl(fd.get(), F_SETFD, FD_CLOEXEC), SyscallSucceeds()); + ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); +} + +TEST(FcntlTest, DupFDCloExecWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + // Open a file descriptor with FD_CLOEXEC descriptor flag not set. + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + int nfd; + ASSERT_THAT(nfd = fcntl(fd.get(), F_DUPFD_CLOEXEC, 0), SyscallSucceeds()); + FileDescriptor dup_fd(nfd); + + // Check for the FD_CLOEXEC flag. + ASSERT_THAT(fcntl(dup_fd.get(), F_GETFD), + SyscallSucceedsWithValue(FD_CLOEXEC)); +} + TEST(FcntlTest, ClearCloExec) { // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag set. FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_CLOEXEC)); @@ -264,6 +328,22 @@ TEST(FcntlTest, GetAllFlags) { EXPECT_EQ(rflags, expected); } +// When O_PATH is specified in flags, flag bits other than O_CLOEXEC, +// O_DIRECTORY, and O_NOFOLLOW are ignored. +TEST(FcntlTest, GetOpathFlag) { + SKIP_IF(IsRunningWithVFS1()); + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + int flags = O_RDWR | O_DIRECT | O_SYNC | O_NONBLOCK | O_APPEND | O_PATH | + O_NOFOLLOW | O_DIRECTORY; + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), flags)); + + int expected = O_PATH | O_NOFOLLOW | O_DIRECTORY; + + int rflags; + EXPECT_THAT(rflags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); + EXPECT_EQ(rflags, expected); +} + TEST(FcntlTest, SetFlags) { TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), 0)); @@ -392,6 +472,22 @@ TEST_F(FcntlLockTest, SetLockBadOpenFlagsRead) { EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl1), SyscallFailsWithErrno(EBADF)); } +TEST_F(FcntlLockTest, SetLockWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + struct flock fl0; + fl0.l_type = F_WRLCK; + fl0.l_whence = SEEK_SET; + fl0.l_start = 0; + fl0.l_len = 0; // Lock all file + + // Expect that setting a write lock using a Opath file descriptor + // won't work. + EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallFailsWithErrno(EBADF)); +} + TEST_F(FcntlLockTest, SetLockUnlockOnNothing) { auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = @@ -1642,6 +1738,202 @@ TEST(FcntlTest, SetFlSetOwnSetSigDoNotRace) { } } +TEST_F(FcntlLockTest, GetLockOnNothing) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + + struct flock fl; + fl.l_type = F_RDLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 40; + ASSERT_THAT(fcntl(fd.get(), F_GETLK, &fl), SyscallSucceeds()); + ASSERT_TRUE(fl.l_type == F_UNLCK); +} + +TEST_F(FcntlLockTest, GetLockOnLockSameProcess) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + + struct flock fl; + fl.l_type = F_RDLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 40; + ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + ASSERT_THAT(fcntl(fd.get(), F_GETLK, &fl), SyscallSucceeds()); + ASSERT_TRUE(fl.l_type == F_UNLCK); + + fl.l_type = F_WRLCK; + ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + ASSERT_THAT(fcntl(fd.get(), F_GETLK, &fl), SyscallSucceeds()); + ASSERT_TRUE(fl.l_type == F_UNLCK); +} + +TEST_F(FcntlLockTest, GetReadLockOnReadLock) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + + struct flock fl; + fl.l_type = F_RDLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 40; + ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + + pid_t child_pid = fork(); + if (child_pid == 0) { + TEST_CHECK(fcntl(fd.get(), F_GETLK, &fl) >= 0); + TEST_CHECK(fl.l_type == F_UNLCK); + _exit(0); + } + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +TEST_F(FcntlLockTest, GetReadLockOnWriteLock) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + + struct flock fl; + fl.l_type = F_WRLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 40; + ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + + fl.l_type = F_RDLCK; + pid_t child_pid = fork(); + if (child_pid == 0) { + TEST_CHECK(fcntl(fd.get(), F_GETLK, &fl) >= 0); + TEST_CHECK(fl.l_type == F_WRLCK); + TEST_CHECK(fl.l_pid == getppid()); + _exit(0); + } + + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +TEST_F(FcntlLockTest, GetWriteLockOnReadLock) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + + struct flock fl; + fl.l_type = F_RDLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 40; + ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + + fl.l_type = F_WRLCK; + pid_t child_pid = fork(); + if (child_pid == 0) { + TEST_CHECK(fcntl(fd.get(), F_GETLK, &fl) >= 0); + TEST_CHECK(fl.l_type == F_RDLCK); + TEST_CHECK(fl.l_pid == getppid()); + _exit(0); + } + + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +TEST_F(FcntlLockTest, GetWriteLockOnWriteLock) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + + struct flock fl; + fl.l_type = F_WRLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 40; + ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + + pid_t child_pid = fork(); + if (child_pid == 0) { + TEST_CHECK(fcntl(fd.get(), F_GETLK, &fl) >= 0); + TEST_CHECK(fl.l_type == F_WRLCK); + TEST_CHECK(fl.l_pid == getppid()); + _exit(0); + } + + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Tests that the pid returned from F_GETLK is relative to the caller's PID +// namespace. +TEST_F(FcntlLockTest, GetLockRespectsPIDNamespace) { + SKIP_IF(IsRunningWithVFS1()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + std::string filename = file.path(); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_RDWR, 0666)); + + // Lock in the parent process. + struct flock fl; + fl.l_type = F_WRLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 40; + ASSERT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + + auto child_getlk = [](void* filename) { + int fd = open((char*)filename, O_RDWR, 0666); + TEST_CHECK(fd >= 0); + + struct flock fl; + fl.l_type = F_WRLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 40; + TEST_CHECK(fcntl(fd, F_GETLK, &fl) >= 0); + TEST_CHECK(fl.l_type == F_WRLCK); + // Parent PID should be 0 in the child PID namespace. + TEST_CHECK(fl.l_pid == 0); + close(fd); + return 0; + }; + + // Set up child process in a new PID namespace. + constexpr int kStackSize = 4096; + Mapping stack = ASSERT_NO_ERRNO_AND_VALUE( + Mmap(nullptr, kStackSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0)); + pid_t child_pid; + ASSERT_THAT( + child_pid = clone(child_getlk, (char*)stack.ptr() + stack.len(), + CLONE_NEWPID | SIGCHLD, (void*)filename.c_str()), + SyscallSucceeds()); + + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc index 93c692dd6..2f2b14037 100644 --- a/test/syscalls/linux/getdents.cc +++ b/test/syscalls/linux/getdents.cc @@ -429,6 +429,32 @@ TYPED_TEST(GetdentsTest, NotDir) { SyscallFailsWithErrno(ENOTDIR)); } +// Test that getdents returns EBADF when called on an opath file. +TYPED_TEST(GetdentsTest, OpathFile) { + SKIP_IF(IsRunningWithVFS1()); + + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + typename TestFixture::DirentBufferType dirents(256); + EXPECT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(), + dirents.Size()), + SyscallFailsWithErrno(EBADF)); +} + +// Test that getdents returns EBADF when called on an opath directory. +TYPED_TEST(GetdentsTest, OpathDirectory) { + SKIP_IF(IsRunningWithVFS1()); + + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_PATH | O_DIRECTORY)); + + typename TestFixture::DirentBufferType dirents(256); + ASSERT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(), + dirents.Size()), + SyscallFailsWithErrno(EBADF)); +} + // Test that SEEK_SET to 0 causes getdents to re-read the entries. TYPED_TEST(GetdentsTest, SeekResetsCursor) { // . and .. should be in an otherwise empty directory. diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc index 8137f0e29..a88c89e20 100644 --- a/test/syscalls/linux/inotify.cc +++ b/test/syscalls/linux/inotify.cc @@ -36,6 +36,7 @@ #include "test/util/epoll_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -315,8 +316,7 @@ PosixErrorOr<std::vector<Event>> DrainEvents(int fd) { } PosixErrorOr<FileDescriptor> InotifyInit1(int flags) { - int fd; - EXPECT_THAT(fd = inotify_init1(flags), SyscallSucceeds()); + int fd = inotify_init1(flags); if (fd < 0) { return PosixError(errno, "inotify_init1() failed"); } @@ -325,9 +325,7 @@ PosixErrorOr<FileDescriptor> InotifyInit1(int flags) { PosixErrorOr<int> InotifyAddWatch(int fd, const std::string& path, uint32_t mask) { - int wd; - EXPECT_THAT(wd = inotify_add_watch(fd, path.c_str(), mask), - SyscallSucceeds()); + int wd = inotify_add_watch(fd, path.c_str(), mask); if (wd < 0) { return PosixError(errno, "inotify_add_watch() failed"); } @@ -784,6 +782,38 @@ TEST(Inotify, MoveWatchedTargetGeneratesEvents) { EXPECT_EQ(events[0].cookie, events[1].cookie); } +// Tests that close events are only emitted when a file description drops its +// last reference. +TEST(Inotify, DupFD) { + SKIP_IF(IsRunningWithVFS1()); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS)); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); + + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_OPEN, wd), + })); + + fd.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({})); + + fd2.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_CLOSE_NOWRITE, wd), + })); +} + TEST(Inotify, CoalesceEvents) { const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const FileDescriptor fd = @@ -1779,11 +1809,9 @@ TEST(Inotify, Sendfile) { EXPECT_THAT(out_events, Are({Event(IN_MODIFY, out_wd)})); } -// On Linux, inotify behavior is not very consistent with splice(2). We try our -// best to emulate Linux for very basic calls to splice. TEST(Inotify, SpliceOnWatchTarget) { - int pipes[2]; - ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds()); + int pipefds[2]; + ASSERT_THAT(pipe2(pipefds, O_NONBLOCK), SyscallSucceeds()); const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const FileDescriptor inotify_fd = @@ -1798,15 +1826,20 @@ TEST(Inotify, SpliceOnWatchTarget) { const int file_wd = ASSERT_NO_ERRNO_AND_VALUE( InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS)); - EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr, 1, /*flags=*/0), + EXPECT_THAT(splice(fd.get(), nullptr, pipefds[1], nullptr, 1, /*flags=*/0), SyscallSucceedsWithValue(1)); - // Surprisingly, events are not generated in Linux if we read from a file. + // Surprisingly, events may not be generated in Linux if we read from a file. + // fs/splice.c:generic_file_splice_read, which is used most often, does not + // generate events, whereas fs/splice.c:default_file_splice_read does. std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - ASSERT_THAT(events, Are({})); + if (IsRunningOnGvisor() && !IsRunningWithVFS1()) { + ASSERT_THAT(events, Are({Event(IN_ACCESS, dir_wd, Basename(file.path())), + Event(IN_ACCESS, file_wd)})); + } - EXPECT_THAT(splice(pipes[0], nullptr, fd.get(), nullptr, 1, /*flags=*/0), + EXPECT_THAT(splice(pipefds[0], nullptr, fd.get(), nullptr, 1, /*flags=*/0), SyscallSucceedsWithValue(1)); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); @@ -1817,8 +1850,8 @@ TEST(Inotify, SpliceOnWatchTarget) { } TEST(Inotify, SpliceOnInotifyFD) { - int pipes[2]; - ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds()); + int pipefds[2]; + ASSERT_THAT(pipe2(pipefds, O_NONBLOCK), SyscallSucceeds()); const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const FileDescriptor fd = @@ -1834,11 +1867,11 @@ TEST(Inotify, SpliceOnInotifyFD) { char buf; EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); - EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr, + EXPECT_THAT(splice(fd.get(), nullptr, pipefds[1], nullptr, sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK), SyscallSucceedsWithValue(sizeof(struct inotify_event))); - const FileDescriptor read_fd(pipes[0]); + const FileDescriptor read_fd(pipefds[0]); const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get())); ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)})); @@ -1936,24 +1969,29 @@ TEST(Inotify, Xattr) { } TEST(Inotify, Exec) { - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath bin = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(dir.path(), "/bin/true")); - + SKIP_IF(IsRunningWithVFS1()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), bin.path(), IN_ALL_EVENTS)); + InotifyAddWatch(fd.get(), "/bin/true", IN_ALL_EVENTS)); // Perform exec. - ScopedThread t([&bin]() { - ASSERT_THAT(execl(bin.path().c_str(), bin.path().c_str(), (char*)nullptr), - SyscallSucceeds()); - }); - t.Join(); + pid_t child = -1; + int execve_errno = -1; + auto kill = ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec("/bin/true", {}, {}, nullptr, &child, &execve_errno)); + ASSERT_EQ(0, execve_errno); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_EQ(0, status); + + // Process cleanup no longer needed. + kill.Release(); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - EXPECT_THAT(events, Are({Event(IN_OPEN, wd), Event(IN_ACCESS, wd)})); + EXPECT_THAT(events, Are({Event(IN_OPEN, wd), Event(IN_ACCESS, wd), + Event(IN_CLOSE_NOWRITE, wd)})); } // Watches without IN_EXCL_UNLINK, should continue to emit events for file diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc index b0a07a064..9b16d1558 100644 --- a/test/syscalls/linux/ioctl.cc +++ b/test/syscalls/linux/ioctl.cc @@ -76,6 +76,19 @@ TEST_F(IoctlTest, InvalidControlNumber) { EXPECT_THAT(ioctl(STDOUT_FILENO, 0), SyscallFailsWithErrno(ENOTTY)); } +TEST_F(IoctlTest, IoctlWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_PATH)); + + int set = 1; + EXPECT_THAT(ioctl(fd.get(), FIONBIO, &set), SyscallFailsWithErrno(EBADF)); + + EXPECT_THAT(ioctl(fd.get(), FIONCLEX), SyscallFailsWithErrno(EBADF)); + + EXPECT_THAT(ioctl(fd.get(), FIOCLEX), SyscallFailsWithErrno(EBADF)); +} + TEST_F(IoctlTest, FIONBIOSucceeds) { EXPECT_FALSE(CheckNonBlocking(fd())); int set = 1; diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc index 544681168..4f9ca1a65 100644 --- a/test/syscalls/linux/link.cc +++ b/test/syscalls/linux/link.cc @@ -50,6 +50,8 @@ bool IsSameFile(const std::string& f1, const std::string& f2) { return stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino; } +// TODO(b/178640646): Add test for linkat with AT_EMPTY_PATH + TEST(LinkTest, CanCreateLinkFile) { auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const std::string newname = NewTempAbsPath(); @@ -235,6 +237,59 @@ TEST(LinkTest, AbsPathsWithNonDirFDs) { SyscallSucceeds()); } +TEST(LinkTest, NewDirFDWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const std::string newname_parent = NewTempAbsPath(); + const std::string newname_base = "child"; + const std::string newname = JoinPath(newname_parent, newname_base); + + // Create newname_parent directory, and get an FD. + EXPECT_THAT(mkdir(newname_parent.c_str(), 0777), SyscallSucceeds()); + const FileDescriptor newname_parent_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(newname_parent, O_DIRECTORY | O_PATH)); + + // Link newname to oldfile, using newname_parent_fd. + EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), newname_parent_fd.get(), + newname.c_str(), 0), + SyscallSucceeds()); + + EXPECT_TRUE(IsSameFile(oldfile.path(), newname)); +} + +TEST(LinkTest, RelPathsNonDirFDsWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + + // Create a file that will be passed as the directory fd for old/new names. + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor file_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + + // Using file_fd as olddirfd will fail. + EXPECT_THAT(linkat(file_fd.get(), "foo", AT_FDCWD, "bar", 0), + SyscallFailsWithErrno(ENOTDIR)); + + // Using file_fd as newdirfd will fail. + EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), file_fd.get(), "bar", 0), + SyscallFailsWithErrno(ENOTDIR)); +} + +TEST(LinkTest, AbsPathsNonDirFDsWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + + auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const std::string newname = NewTempAbsPath(); + + // Create a file that will be passed as the directory fd for old/new names. + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor file_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + + // Using file_fd as the dirfds is OK as long as paths are absolute. + EXPECT_THAT(linkat(file_fd.get(), oldfile.path().c_str(), file_fd.get(), + newname.c_str(), 0), + SyscallSucceeds()); +} + TEST(LinkTest, LinkDoesNotFollowSymlinks) { // Create oldfile, and oldsymlink which points to it. auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc index 5a1973f60..6e714b12c 100644 --- a/test/syscalls/linux/madvise.cc +++ b/test/syscalls/linux/madvise.cc @@ -179,9 +179,9 @@ TEST(MadviseDontforkTest, DontforkShared) { // First page is mapped in child and modifications are visible to parent // via the shared mapping. TEST_CHECK(IsMapped(ms1.addr())); - ExpectAllMappingBytes(ms1, 2); + CheckAllMappingBytes(ms1, 2); memset(ms1.ptr(), 1, kPageSize); - ExpectAllMappingBytes(ms1, 1); + CheckAllMappingBytes(ms1, 1); // Second page must not be mapped in child. TEST_CHECK(!IsMapped(ms2.addr())); @@ -222,9 +222,9 @@ TEST(MadviseDontforkTest, DontforkAnonPrivate) { // page. The mapping is private so the modifications are not visible to // the parent. TEST_CHECK(IsMapped(mp1.addr())); - ExpectAllMappingBytes(mp1, 1); + CheckAllMappingBytes(mp1, 1); memset(mp1.ptr(), 11, kPageSize); - ExpectAllMappingBytes(mp1, 11); + CheckAllMappingBytes(mp1, 11); // Verify second page is not mapped. TEST_CHECK(!IsMapped(mp2.addr())); @@ -233,9 +233,9 @@ TEST(MadviseDontforkTest, DontforkAnonPrivate) { // page. The mapping is private so the modifications are not visible to // the parent. TEST_CHECK(IsMapped(mp3.addr())); - ExpectAllMappingBytes(mp3, 3); + CheckAllMappingBytes(mp3, 3); memset(mp3.ptr(), 13, kPageSize); - ExpectAllMappingBytes(mp3, 13); + CheckAllMappingBytes(mp3, 13); }; EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 83546830d..93a6d9cde 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -930,6 +930,18 @@ TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) { SyscallFailsWithErrno(EACCES)); } +// Mmap not allowed on O_PATH FDs. +TEST_F(MMapFileTest, MmapFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + uintptr_t addr; + EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd.get(), 0), + SyscallFailsWithErrno(EBADF)); +} + // The FD must be readable. TEST_P(MMapFileParamTest, WriteOnlyFd) { const FileDescriptor fd = diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index fcd162ca2..733b17834 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -45,7 +45,7 @@ namespace { // * O_CREAT // * O_DIRECTORY // * O_NOFOLLOW -// * O_PATH <- Will we ever support this? +// * O_PATH // // Special operations on open: // * O_EXCL @@ -517,6 +517,26 @@ TEST_F(OpenTest, OpenWithStrangeFlags) { EXPECT_THAT(read(fd.get(), &c, 1), SyscallFailsWithErrno(EBADF)); } +TEST_F(OpenTest, OpenWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + const DisableSave ds; // Permissions are dropped. + std::string path = NewTempAbsPath(); + + // Create a file without user permissions. + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055)); + + // Cannot open file as read only because we are owner and have no permissions + // set. + EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); + + // Can open file with O_PATH because don't need permissions on the object when + // opening with O_PATH. + ASSERT_NO_ERRNO(Open(path, O_PATH)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 2ed4f6f9c..d25be0e30 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -548,13 +548,7 @@ TEST_P(RawPacketTest, SetSocketSendBuf) { ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), SyscallSucceeds()); - // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack - // matches linux behavior. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } - + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc index a9bfdb37b..999c8ab6b 100644 --- a/test/syscalls/linux/ping_socket.cc +++ b/test/syscalls/linux/ping_socket.cc @@ -31,51 +31,36 @@ namespace gvisor { namespace testing { namespace { -class PingSocket : public ::testing::Test { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // The loopback address. - struct sockaddr_in addr_; -}; - -void PingSocket::SetUp() { - // On some hosts ping sockets are restricted to specific groups using the - // sysctl "ping_group_range". - int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP); - if (s < 0 && errno == EPERM) { - GTEST_SKIP(); - } - close(s); - - addr_ = {}; - // Just a random port as the destination port number is irrelevant for ping - // sockets. - addr_.sin_port = 12345; - addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr_.sin_family = AF_INET; -} - -void PingSocket::TearDown() {} - // Test ICMP port exhaustion returns EAGAIN. // // We disable both random/cooperative S/R for this test as it makes way too many // syscalls. -TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) { +TEST(PingSocket, ICMPPortExhaustion_NoRandomSave) { DisableSave ds; + + { + auto s = Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP); + if (!s.ok()) { + ASSERT_EQ(s.error().errno_value(), EACCES); + GTEST_SKIP(); + } + } + + const struct sockaddr_in addr = { + .sin_family = AF_INET, + .sin_addr = + { + .s_addr = htonl(INADDR_LOOPBACK), + }, + }; + std::vector<FileDescriptor> sockets; constexpr int kSockets = 65536; - addr_.sin_port = 0; for (int i = 0; i < kSockets; i++) { auto s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); - int ret = connect(s.get(), reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)); + int ret = connect(s.get(), reinterpret_cast<const struct sockaddr*>(&addr), + sizeof(addr)); if (ret == 0) { sockets.push_back(std::move(s)); continue; diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index bcdbbb044..c74990ba1 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -77,6 +77,16 @@ TEST_F(Pread64Test, WriteOnlyNotReadable) { EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EBADF)); } +TEST_F(Pread64Test, Pread64WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + char buf[1024]; + EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EBADF)); +} + TEST_F(Pread64Test, DirNotReadable) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY)); diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc index 5b0743fe9..1c40f0915 100644 --- a/test/syscalls/linux/preadv.cc +++ b/test/syscalls/linux/preadv.cc @@ -89,6 +89,20 @@ TEST(PreadvTest, MMConcurrencyStress) { // The test passes if it neither deadlocks nor crashes the OS. } +// This test calls preadv with an O_PATH fd. +TEST(PreadvTest, PreadvWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + struct iovec iov; + iov.iov_base = nullptr; + iov.iov_len = 0; + + EXPECT_THAT(preadv(fd.get(), &iov, 1, 0), SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc index 4a9acd7ae..cb58719c4 100644 --- a/test/syscalls/linux/preadv2.cc +++ b/test/syscalls/linux/preadv2.cc @@ -226,6 +226,24 @@ TEST(Preadv2Test, TestUnreadableFile) { SyscallFailsWithErrno(EBADF)); } +// This test calls preadv2 with a file opened with O_PATH. +TEST(Preadv2Test, Preadv2WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + auto iov = absl::make_unique<struct iovec[]>(1); + iov[0].iov_base = nullptr; + iov[0].iov_len = 0; + + EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/0, + /*flags=*/0), + SyscallFailsWithErrno(EBADF)); +} + // Calling preadv2 with a non-negative offset calls preadv. Calling preadv with // an unseekable file is not allowed. A pipe is used for an unseekable file. TEST(Preadv2Test, TestUnseekableFileInvalid) { diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index 0b174e2be..85ff258df 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -1338,6 +1338,7 @@ TEST_F(JobControlTest, SetTTYDifferentSession) { TEST_PCHECK(waitpid(grandchild, &gcwstatus, 0) == grandchild); TEST_PCHECK(gcwstatus == 0); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, ReleaseTTY) { @@ -1515,7 +1516,8 @@ TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // - creates a child process in a new process group // - sets that child as the foreground process group // - kills its child and sets itself as the foreground process group. -TEST_F(JobControlTest, SetForegroundProcessGroup) { +// TODO(gvisor.dev/issue/5357): Fix and enable. +TEST_F(JobControlTest, DISABLED_SetForegroundProcessGroup) { auto res = RunInChild([=]() { TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); @@ -1557,6 +1559,7 @@ TEST_F(JobControlTest, SetForegroundProcessGroup) { TEST_PCHECK(pgid = getpgid(0) == 0); TEST_PCHECK(!tcsetpgrp(replica_.get(), pgid)); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { @@ -1576,8 +1579,9 @@ TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { ASSERT_NO_ERRNO(ret); } -TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { - auto ret = RunInChild([=]() { +// TODO(gvisor.dev/issue/5357): Fix and enable. +TEST_F(JobControlTest, DISABLED_SetForegroundProcessGroupEmptyProcessGroup) { + auto res = RunInChild([=]() { TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); // Create a new process, put it in a new process group, make that group the @@ -1595,6 +1599,7 @@ TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) != 0 && errno == ESRCH); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index e69794910..1b2f25363 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -77,6 +77,17 @@ TEST_F(Pwrite64, Overflow) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(Pwrite64, Pwrite64WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + std::vector<char> buf(1); + EXPECT_THAT(PwriteFd(fd.get(), buf.data(), 1, 0), + SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index 63b686c62..00aed61b4 100644 --- a/test/syscalls/linux/pwritev2.cc +++ b/test/syscalls/linux/pwritev2.cc @@ -283,6 +283,23 @@ TEST(Pwritev2Test, ReadOnlyFile) { SyscallFailsWithErrno(EBADF)); } +TEST(Pwritev2Test, Pwritev2WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + char buf[16]; + struct iovec iov; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); + + EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/0), + SyscallFailsWithErrno(EBADF)); +} + // This test calls pwritev2 with an invalid flag. TEST(Pwritev2Test, InvalidFlag) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 955bcee4b..32924466f 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -621,13 +621,7 @@ TEST_P(RawSocketTest, SetSocketSendBuf) { ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), SyscallSucceeds()); - // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack - // matches linux behavior. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } - + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc index 2633ba31b..98d5e432d 100644 --- a/test/syscalls/linux/read.cc +++ b/test/syscalls/linux/read.cc @@ -112,6 +112,15 @@ TEST_F(ReadTest, ReadDirectoryFails) { EXPECT_THAT(ReadFd(file.get(), buf.data(), 1), SyscallFailsWithErrno(EISDIR)); } +TEST_F(ReadTest, ReadWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + std::vector<char> buf(1); + EXPECT_THAT(ReadFd(fd.get(), buf.data(), 1), SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index baaf9f757..86808d255 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -251,6 +251,20 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) { SyscallFailsWithErrno(EFAULT)); } +TEST_F(ReadvTest, ReadvWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + char buffer[1024]; + struct iovec iov[1]; + iov[0].iov_base = buffer; + iov[0].iov_len = 1024; + + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + + ASSERT_THAT(readv(fd.get(), iov, 1), SyscallFailsWithErrno(EBADF)); +} + // This test depends on the maximum extent of a single readv() syscall, so // we can't tolerate interruption from saving. TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) { diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index d6e8b3e59..baf794152 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -256,32 +256,26 @@ TEST(ShmTest, IpcInfo) { } TEST(ShmTest, ShmInfo) { - struct shm_info info; - - // We generally can't know what other processes on a linux machine - // does with shared memory segments, so we can't test specific - // numbers on Linux. When running under gvisor, we're guaranteed to - // be the only ones using shm, so we can easily verify machine-wide - // numbers. - if (IsRunningOnGvisor()) { - ASSERT_NO_ERRNO(Shmctl(0, SHM_INFO, &info)); - EXPECT_EQ(info.used_ids, 0); - EXPECT_EQ(info.shm_tot, 0); - EXPECT_EQ(info.shm_rss, 0); - EXPECT_EQ(info.shm_swp, 0); - } + // Take a snapshot of the system before the test runs. + struct shm_info snap; + ASSERT_NO_ERRNO(Shmctl(0, SHM_INFO, &snap)); const ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); const char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); + struct shm_info info; ASSERT_NO_ERRNO(Shmctl(1, SHM_INFO, &info)); + // We generally can't know what other processes on a linux machine do with + // shared memory segments, so we can't test specific numbers on Linux. When + // running under gvisor, we're guaranteed to be the only ones using shm, so + // we can easily verify machine-wide numbers. if (IsRunningOnGvisor()) { ASSERT_NO_ERRNO(Shmctl(shm.id(), SHM_INFO, &info)); - EXPECT_EQ(info.used_ids, 1); - EXPECT_EQ(info.shm_tot, kAllocSize / kPageSize); - EXPECT_EQ(info.shm_rss, kAllocSize / kPageSize); + EXPECT_EQ(info.used_ids, snap.used_ids + 1); + EXPECT_EQ(info.shm_tot, snap.shm_tot + (kAllocSize / kPageSize)); + EXPECT_EQ(info.shm_rss, snap.shm_rss + (kAllocSize / kPageSize)); EXPECT_EQ(info.shm_swp, 0); // Gvisor currently never swaps. } @@ -378,18 +372,18 @@ TEST(ShmDeathTest, SegmentNotAccessibleAfterDetach) { SetupGvisorDeathTest(); const auto rest = [&] { - ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( + ShmSegment shm = TEST_CHECK_NO_ERRNO_AND_VALUE( Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); + char* addr = TEST_CHECK_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); // Mark the segment as destroyed so it's automatically cleaned up when we // crash below. We can't rely on the standard cleanup since the destructor // will not run after the SIGSEGV. Note that this doesn't destroy the // segment immediately since we're still attached to it. - ASSERT_NO_ERRNO(shm.Rmid()); + TEST_CHECK_NO_ERRNO(shm.Rmid()); addr[0] = 'x'; - ASSERT_NO_ERRNO(Shmdt(addr)); + TEST_CHECK_NO_ERRNO(Shmdt(addr)); // This access should cause a SIGSEGV. addr[0] = 'x'; diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigreturn_amd64.cc index 6227774a4..6227774a4 100644 --- a/test/syscalls/linux/sigiret.cc +++ b/test/syscalls/linux/sigreturn_amd64.cc diff --git a/test/syscalls/linux/sigreturn_arm64.cc b/test/syscalls/linux/sigreturn_arm64.cc new file mode 100644 index 000000000..2c19e2984 --- /dev/null +++ b/test/syscalls/linux/sigreturn_arm64.cc @@ -0,0 +1,97 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/unistd.h> +#include <signal.h> +#include <sys/syscall.h> +#include <sys/types.h> +#include <sys/ucontext.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/util/logging.h" +#include "test/util/signal_util.h" +#include "test/util/test_util.h" +#include "test/util/timer_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr uint64_t kOrigX7 = 0xdeadbeeffacefeed; + +void sigvtalrm(int sig, siginfo_t* siginfo, void* _uc) { + ucontext_t* uc = reinterpret_cast<ucontext_t*>(_uc); + + // Verify that: + // - x7 value in mcontext_t matches kOrigX7. + if (uc->uc_mcontext.regs[7] == kOrigX7) { + // Modify the value x7 in the ucontext. This is the value seen by the + // application after the signal handler returns. + uc->uc_mcontext.regs[7] = ~kOrigX7; + } +} + +int testX7(uint64_t* val, uint64_t sysno, uint64_t tgid, uint64_t tid, + uint64_t signo) { + register uint64_t* x9 __asm__("x9") = val; + register uint64_t x8 __asm__("x8") = sysno; + register uint64_t x0 __asm__("x0") = tgid; + register uint64_t x1 __asm__("x1") = tid; + register uint64_t x2 __asm__("x2") = signo; + + // Initialize x7, send SIGVTALRM to itself and read x7. + __asm__( + "ldr x7, [x9, 0]\n" + "svc 0\n" + "str x7, [x9, 0]\n" + : "=r"(x0) + : "r"(x0), "r"(x1), "r"(x2), "r"(x9), "r"(x8) + : "x7"); + return x0; +} + +// On ARM64, when ptrace stops on a system call, it uses the x7 register to +// indicate whether the stop has been signalled from syscall entry or syscall +// exit. This means that we can't get a value of this register and we can't +// change it. More details are in the comment for tracehook_report_syscall in +// arch/arm64/kernel/ptrace.c. +// +// CheckR7 checks that the ptrace platform handles the x7 register properly. +TEST(SigreturnTest, CheckX7) { + // Setup signal handler for SIGVTALRM. + struct sigaction sa = {}; + sigfillset(&sa.sa_mask); + sa.sa_sigaction = sigvtalrm; + sa.sa_flags = SA_SIGINFO; + auto const action_cleanup = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGVTALRM, sa)); + + auto const mask_cleanup = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGVTALRM)); + + uint64_t x7 = kOrigX7; + + testX7(&x7, __NR_tgkill, getpid(), syscall(__NR_gettid), SIGVTALRM); + + // The following check verifies that %x7 was not clobbered + // when returning from the signal handler (via sigreturn(2)). + EXPECT_EQ(x7, ~kOrigX7); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index 32f583581..b616c2c87 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -16,6 +16,7 @@ #include <sys/stat.h> #include <sys/statfs.h> #include <sys/types.h> +#include <sys/wait.h> #include <unistd.h> #include "gtest/gtest.h" @@ -90,8 +91,7 @@ TEST(SocketTest, UnixSocketStat) { EXPECT_EQ(statbuf.st_mode, S_IFSOCK | (sock_perm & ~mask)); // Timestamps should be equal and non-zero. - // TODO(b/158882152): Sockets currently don't implement timestamps. - if (!IsRunningOnGvisor()) { + if (!IsRunningWithVFS1()) { EXPECT_NE(statbuf.st_atime, 0); EXPECT_EQ(statbuf.st_atime, statbuf.st_mtime); EXPECT_EQ(statbuf.st_atime, statbuf.st_ctime); @@ -111,6 +111,77 @@ TEST(SocketTest, UnixSocketStatFS) { EXPECT_EQ(st.f_namelen, NAME_MAX); } +TEST(SocketTest, UnixSCMRightsOnlyPassedOnce_NoRandomSave) { + const DisableSave ds; + + int sockets[2]; + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets), SyscallSucceeds()); + // Send more than what will fit inside the send/receive buffers, so that it is + // split into multiple messages. + constexpr int kBufSize = 0x100000; + + pid_t pid = fork(); + if (pid == 0) { + TEST_PCHECK(close(sockets[0]) == 0); + + // Construct a message with some control message. + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(int))] = {}; + std::vector<char> buf(kBufSize); + struct iovec iov = {}; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = CMSG_LEN(sizeof(int)); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + ((int*)CMSG_DATA(cmsg))[0] = sockets[1]; + + iov.iov_base = buf.data(); + iov.iov_len = kBufSize; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + int n = sendmsg(sockets[1], &msg, 0); + TEST_PCHECK(n == kBufSize); + TEST_PCHECK(shutdown(sockets[1], SHUT_RDWR) == 0); + TEST_PCHECK(close(sockets[1]) == 0); + _exit(0); + } + + close(sockets[1]); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(int))] = {}; + std::vector<char> buf(kBufSize); + struct iovec iov = {}; + msg.msg_control = &control; + msg.msg_controllen = sizeof(control); + + iov.iov_base = buf.data(); + iov.iov_len = kBufSize; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + // The control message should only be present in the first message received. + int n; + ASSERT_THAT(n = recvmsg(sockets[0], &msg, 0), SyscallSucceeds()); + ASSERT_GT(n, 0); + ASSERT_EQ(msg.msg_controllen, CMSG_SPACE(sizeof(int))); + + while (n > 0) { + ASSERT_THAT(n = recvmsg(sockets[0], &msg, 0), SyscallSucceeds()); + ASSERT_EQ(msg.msg_controllen, 0); + } + + close(sockets[0]); + + int status; + ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); + ASSERT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + using SocketOpenTest = ::testing::TestWithParam<int>; // UDS cannot be opened. diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 831d96262..a73987a7e 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -65,6 +65,33 @@ TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) { SyscallSucceeds()); } +TEST_P(TCPSocketPairTest, CheckTcpInfoFields) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char buf[10] = {}; + ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + // Wait until second_fd sees the data and then recv it. + struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0}; + constexpr int kPollTimeoutMs = 2000; // Wait up to 2 seconds for the data. + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + struct tcp_info opt = {}; + socklen_t optLen = sizeof(opt); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen), + SyscallSucceeds()); + + // Validates the received tcp_info fields. + EXPECT_EQ(opt.tcpi_ca_state, 0); + EXPECT_GT(opt.tcpi_snd_cwnd, 0); + EXPECT_GT(opt.tcpi_rto, 0); +} + // This test validates that an RST is sent instead of a FIN when data is // unread on calls to close(2). TEST_P(TCPSocketPairTest, RSTSentOnCloseWithUnreadData) { diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index e557572a7..8eec31a46 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -2513,11 +2513,7 @@ TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) { ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len), SyscallSucceeds()); - // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } - + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc index a16899493..22a4ee0d1 100644 --- a/test/syscalls/linux/socket_unix_cmsg.cc +++ b/test/syscalls/linux/socket_unix_cmsg.cc @@ -362,7 +362,7 @@ TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPassTruncationMsgCtrunc) { // BasicFDPassUnalignedRecv starts off by sending a single FD just like // BasicFDPass. The difference is that when calling recvmsg, the length of the -// receive data is only aligned on a 4 byte boundry instead of the normal 8. +// receive data is only aligned on a 4 byte boundary instead of the normal 8. TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecv) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 6e7142a42..72f888659 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -221,6 +221,43 @@ TEST_F(StatTest, TrailingSlashNotCleanedReturnsENOTDIR) { EXPECT_THAT(lstat(bad_path.c_str(), &buf), SyscallFailsWithErrno(ENOTDIR)); } +TEST_F(StatTest, FstatFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + struct stat st; + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + + // Stat the directory. + ASSERT_THAT(fstat(fd.get(), &st), SyscallSucceeds()); +} + +TEST_F(StatTest, FstatDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + struct stat st; + TempPath tmpdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmpdir.path().c_str(), O_PATH | O_DIRECTORY)); + + // Stat the directory. + ASSERT_THAT(fstat(dirfd.get(), &st), SyscallSucceeds()); +} + +// fstatat with an O_PATH fd +TEST_F(StatTest, FstatatDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath tmpdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmpdir.path().c_str(), O_PATH | O_DIRECTORY)); + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + + struct stat st = {}; + EXPECT_THAT(fstatat(dirfd.get(), tmpfile.path().c_str(), &st, 0), + SyscallSucceeds()); + EXPECT_FALSE(S_ISDIR(st.st_mode)); + EXPECT_TRUE(S_ISREG(st.st_mode)); +} + // Test fstatating a symlink directory. TEST_F(StatTest, FstatatSymlinkDir) { // Create a directory and symlink to it. diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc index f0fb166bd..d4ea8e026 100644 --- a/test/syscalls/linux/statfs.cc +++ b/test/syscalls/linux/statfs.cc @@ -64,6 +64,16 @@ TEST(FstatfsTest, InternalTmpfs) { EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds()); } +TEST(FstatfsTest, CanStatFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_PATH)); + + struct statfs st; + EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds()); +} + TEST(FstatfsTest, InternalDevShm) { auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index 4d9eba7f0..ea219a091 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -269,6 +269,36 @@ TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) { EXPECT_THAT(close(dirfd), SyscallSucceeds()); } +TEST(SymlinkTest, SymlinkAtDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string filepath = NewTempAbsPathInDir(dir.path()); + const std::string base = std::string(Basename(filepath)); + FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path().c_str(), O_DIRECTORY | O_PATH)); + + EXPECT_THAT(symlinkat("/dangling", dirfd.get(), base.c_str()), + SyscallSucceeds()); +} + +TEST(SymlinkTest, ReadlinkAtDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string filepath = NewTempAbsPathInDir(dir.path()); + const std::string base = std::string(Basename(filepath)); + ASSERT_THAT(symlink("/dangling", filepath.c_str()), SyscallSucceeds()); + + FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path().c_str(), O_DIRECTORY | O_PATH)); + + std::vector<char> buf(1024); + int linksize; + EXPECT_THAT( + linksize = readlinkat(dirfd.get(), base.c_str(), buf.data(), 1024), + SyscallSucceeds()); + EXPECT_EQ(0, strncmp("/dangling", buf.data(), linksize)); +} + TEST(SymlinkTest, ReadlinkAtDegradedPermissions_NoRandomSave) { // Drop capabilities that allow us to override file and directory permissions. ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); diff --git a/test/syscalls/linux/sync.cc b/test/syscalls/linux/sync.cc index 8aa2525a9..84a2c4ed7 100644 --- a/test/syscalls/linux/sync.cc +++ b/test/syscalls/linux/sync.cc @@ -49,10 +49,20 @@ TEST(SyncTest, SyncFromPipe) { EXPECT_THAT(close(pipes[1]), SyscallSucceeds()); } -TEST(SyncTest, CannotSyncFileSytemAtBadFd) { +TEST(SyncTest, CannotSyncFileSystemAtBadFd) { EXPECT_THAT(syncfs(-1), SyscallFailsWithErrno(EBADF)); } +TEST(SyncTest, CannotSyncFileSystemAtOpathFD) { + SKIP_IF(IsRunningWithVFS1()); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + EXPECT_THAT(syncfs(fd.get()), SyscallFailsWithErrno(EBADF)); +} } // namespace } // namespace testing diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 714848b8e..9028ab024 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -2008,6 +2008,29 @@ TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) { EXPECT_EQ(got, 0); } +// Tests that connecting to an unspecified address results in ECONNREFUSED. +TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) { + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + memset(&addr, 0, addrlen); + addr.ss_family = GetParam(); + auto do_connect = [&addr, addrlen]() { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(addr.ss_family, SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), + addrlen), + SyscallFailsWithErrno(ECONNREFUSED)); + }; + do_connect(); + // Test the v4 mapped address as well. + if (GetParam() == AF_INET6) { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr.s6_addr[10] = sin6->sin6_addr.s6_addr[11] = 0xff; + do_connect(); + } +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc index bfc95ed38..17832c47d 100644 --- a/test/syscalls/linux/truncate.cc +++ b/test/syscalls/linux/truncate.cc @@ -196,6 +196,16 @@ TEST(TruncateTest, FtruncateNonWriteable) { EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +TEST(TruncateTest, FtruncateWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */)); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_PATH)); + EXPECT_THAT(ftruncate(fd.get(), 0), AnyOf(SyscallFailsWithErrno(EBADF), + SyscallFailsWithErrno(EINVAL))); +} + // ftruncate(2) should succeed as long as the file descriptor is writeable, // regardless of whether the file permissions allow writing. TEST(TruncateTest, FtruncateWithoutWritePermission_NoRandomSave) { diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 77bcfbb8a..740992d0a 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -218,6 +218,44 @@ TEST_F(WriteTest, PwriteNoChangeOffset) { EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); } +TEST_F(WriteTest, WriteWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + int fd = f.get(); + + EXPECT_THAT(WriteBytes(fd, 1024), SyscallFailsWithErrno(EBADF)); +} + +TEST_F(WriteTest, WritevWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + int fd = f.get(); + + char buf[16]; + struct iovec iov; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); + + EXPECT_THAT(writev(fd, &iov, /*__count=*/1), SyscallFailsWithErrno(EBADF)); +} + +TEST_F(WriteTest, PwriteWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + int fd = f.get(); + + const std::string data = "hello world\n"; + + EXPECT_THAT(pwrite(fd, data.data(), data.size(), 0), + SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index bd3f829c4..a953a55fe 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -607,6 +607,27 @@ TEST_F(XattrTest, XattrWithFD) { EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds()); } +TEST_F(XattrTest, XattrWithOPath) { + SKIP_IF(IsRunningWithVFS1()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_.c_str(), O_PATH)); + const char name[] = "user.test"; + int val = 1234; + size_t size = sizeof(val); + EXPECT_THAT(fsetxattr(fd.get(), name, &val, size, /*flags=*/0), + SyscallFailsWithErrno(EBADF)); + + int buf; + EXPECT_THAT(fgetxattr(fd.get(), name, &buf, size), + SyscallFailsWithErrno(EBADF)); + + char list[sizeof(name)]; + EXPECT_THAT(flistxattr(fd.get(), list, sizeof(list)), + SyscallFailsWithErrno(EBADF)); + + EXPECT_THAT(fremovexattr(fd.get(), name), SyscallFailsWithErrno(EBADF)); +} + TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) { // Trusted namespace not supported in VFS1. SKIP_IF(IsRunningWithVFS1()); diff --git a/test/util/BUILD b/test/util/BUILD index 1b028a477..e561f3daa 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -172,6 +172,7 @@ cc_library( ":posix_error", ":save_util", ":test_util", + gtest, "@com_google_absl//absl/strings", ], ) diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index b16055dd8..5f1ce0d8a 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -663,5 +663,21 @@ PosixErrorOr<bool> IsOverlayfs(const std::string& path) { return stat.f_type == OVERLAYFS_SUPER_MAGIC; } +PosixError CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2) { + struct stat stat_result1, stat_result2; + int res = fstat(fd1.get(), &stat_result1); + if (res < 0) { + return PosixError(errno, absl::StrCat("fstat ", fd1.get())); + } + + res = fstat(fd2.get(), &stat_result2); + if (res < 0) { + return PosixError(errno, absl::StrCat("fstat ", fd2.get())); + } + EXPECT_EQ(stat_result1.st_dev, stat_result2.st_dev); + EXPECT_EQ(stat_result1.st_ino, stat_result2.st_ino); + + return NoError(); +} } // namespace testing } // namespace gvisor diff --git a/test/util/fs_util.h b/test/util/fs_util.h index c99cf5eb7..2190c3bca 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -191,6 +191,8 @@ PosixErrorOr<bool> IsTmpfs(const std::string& path); // IsOverlayfs returns true if the file at path is backed by overlayfs. PosixErrorOr<bool> IsOverlayfs(const std::string& path); +PosixError CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2); + namespace internal { // Not part of the public API. std::string JoinPathImpl(std::initializer_list<absl::string_view> paths); diff --git a/test/util/logging.cc b/test/util/logging.cc index 5d5e76c46..5fadb076b 100644 --- a/test/util/logging.cc +++ b/test/util/logging.cc @@ -69,9 +69,7 @@ int WriteNumber(int fd, uint32_t val) { } // namespace void CheckFailure(const char* cond, size_t cond_size, const char* msg, - size_t msg_size, bool include_errno) { - int saved_errno = errno; - + size_t msg_size, int errno_value) { constexpr char kCheckFailure[] = "Check failed: "; Write(2, kCheckFailure, sizeof(kCheckFailure) - 1); Write(2, cond, cond_size); @@ -81,10 +79,10 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, Write(2, msg, msg_size); } - if (include_errno) { + if (errno_value != 0) { constexpr char kErrnoMessage[] = " (errno "; Write(2, kErrnoMessage, sizeof(kErrnoMessage) - 1); - WriteNumber(2, saved_errno); + WriteNumber(2, errno_value); Write(2, ")", 1); } diff --git a/test/util/logging.h b/test/util/logging.h index 589166fab..9d224ea05 100644 --- a/test/util/logging.h +++ b/test/util/logging.h @@ -21,7 +21,7 @@ namespace gvisor { namespace testing { void CheckFailure(const char* cond, size_t cond_size, const char* msg, - size_t msg_size, bool include_errno); + size_t msg_size, int errno_value); // If cond is false, aborts the current process. // @@ -30,7 +30,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, do { \ if (!(cond)) { \ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \ - 0, false); \ + 0, 0); \ } \ } while (0) @@ -41,7 +41,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, do { \ if (!(cond)) { \ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \ - sizeof(msg) - 1, false); \ + sizeof(msg) - 1, 0); \ } \ } while (0) @@ -52,7 +52,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, do { \ if (!(cond)) { \ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \ - 0, true); \ + 0, errno); \ } \ } while (0) @@ -63,10 +63,39 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, do { \ if (!(cond)) { \ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \ - sizeof(msg) - 1, true); \ + sizeof(msg) - 1, errno); \ } \ } while (0) +// expr must return PosixErrorOr<T>. The current process is aborted if +// !PosixError<T>.ok(). +// +// This macro is async-signal-safe. +#define TEST_CHECK_NO_ERRNO(expr) \ + ({ \ + auto _expr_result = (expr); \ + if (!_expr_result.ok()) { \ + ::gvisor::testing::CheckFailure( \ + #expr, sizeof(#expr) - 1, nullptr, 0, \ + _expr_result.error().errno_value()); \ + } \ + }) + +// expr must return PosixErrorOr<T>. The current process is aborted if +// !PosixError<T>.ok(). Otherwise, PosixErrorOr<T> value is returned. +// +// This macro is async-signal-safe. +#define TEST_CHECK_NO_ERRNO_AND_VALUE(expr) \ + ({ \ + auto _expr_result = (expr); \ + if (!_expr_result.ok()) { \ + ::gvisor::testing::CheckFailure( \ + #expr, sizeof(#expr) - 1, nullptr, 0, \ + _expr_result.error().errno_value()); \ + } \ + std::move(_expr_result).ValueOrDie(); \ + }) + } // namespace testing } // namespace gvisor diff --git a/test/util/multiprocess_util.cc b/test/util/multiprocess_util.cc index 8b676751b..a6b0de24b 100644 --- a/test/util/multiprocess_util.cc +++ b/test/util/multiprocess_util.cc @@ -154,6 +154,9 @@ PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn) { pid_t pid = fork(); if (pid == 0) { fn(); + TEST_CHECK_MSG(!::testing::Test::HasFailure(), + "EXPECT*/ASSERT* failed. These are not async-signal-safe " + "and must not be called from fn."); _exit(0); } MaybeSave(); diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h index 2f3bf4a6f..840fde4ee 100644 --- a/test/util/multiprocess_util.h +++ b/test/util/multiprocess_util.h @@ -123,7 +123,8 @@ inline PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, // Calls fn in a forked subprocess and returns the exit status of the // subprocess. // -// fn must be async-signal-safe. +// fn must be async-signal-safe. Use of ASSERT/EXPECT functions is prohibited. +// Use TEST_CHECK variants instead. PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn); } // namespace testing diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc index deed0c05b..8522e4c81 100644 --- a/test/util/posix_error.cc +++ b/test/util/posix_error.cc @@ -50,7 +50,7 @@ std::string PosixError::ToString() const { ret = absl::StrCat("PosixError(errno=", errno_, " ", res, ")"); #endif - if (!msg_.empty()) { + if (strnlen(msg_, sizeof(msg_)) > 0) { ret.append(" "); ret.append(msg_); } diff --git a/test/util/posix_error.h b/test/util/posix_error.h index b634a7f78..27557ad44 100644 --- a/test/util/posix_error.h +++ b/test/util/posix_error.h @@ -26,12 +26,18 @@ namespace gvisor { namespace testing { +// PosixError must be async-signal-safe. class ABSL_MUST_USE_RESULT PosixError { public: PosixError() {} + explicit PosixError(int errno_value) : errno_(errno_value) {} - PosixError(int errno_value, std::string msg) - : errno_(errno_value), msg_(std::move(msg)) {} + + PosixError(int errno_value, std::string_view msg) : errno_(errno_value) { + // Check that `msg` will fit, leaving room for '\0' at the end. + TEST_CHECK(msg.size() < sizeof(msg_)); + msg.copy(msg_, msg.size()); + } PosixError(PosixError&& other) = default; PosixError& operator=(PosixError&& other) = default; @@ -45,7 +51,7 @@ class ABSL_MUST_USE_RESULT PosixError { const PosixError& error() const { return *this; } int errno_value() const { return errno_; } - std::string message() const { return msg_; } + const char* message() const { return msg_; } // ToString produces a full string representation of this posix error // including the printable representation of the errno and the error message. @@ -58,7 +64,7 @@ class ABSL_MUST_USE_RESULT PosixError { private: int errno_ = 0; - std::string msg_; + char msg_[1024] = {}; }; template <typename T> diff --git a/tools/bazel.mk b/tools/bazel.mk index 2b20457e9..fb0fc6524 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -191,7 +191,7 @@ endif build_paths = \ (set -euo pipefail; \ $(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1)) 2>&1 \ - | tee /proc/self/fd/2 \ + | tee /dev/fd/2 \ | sed -n -e '/^Target/,$$p' \ | sed -n -e '/^ \($(subst /,\/,$(subst $(SPACE),\|,$(BUILD_ROOTS)))\)/p' \ | sed -e 's/ /\n/g' \ diff --git a/website/cmd/syscalldocs/main.go b/website/cmd/syscalldocs/main.go index 327537214..830d2bac7 100644 --- a/website/cmd/syscalldocs/main.go +++ b/website/cmd/syscalldocs/main.go @@ -52,6 +52,7 @@ layout: docs category: Compatibility weight: 50 permalink: /docs/user_guide/compatibility/{{.OS}}/{{.Arch}}/ +include_in_menu: True --- This table is a reference of {{.OS}} syscalls for the {{.Arch}} architecture and @@ -75,9 +76,9 @@ syscalls. {{if .Undocumented}}{{.Undocumented}} syscalls are not yet documented. </thead> <tbody> {{range $i, $syscall := .Syscalls}} - <tr> - <td><a class="doc-table-anchor" id="{{.Name}}"></a>{{.Number}}</td> - <td><a href="http://man7.org/linux/man-pages/man2/{{.Name}}.2.html" target="_blank" rel="noopener">{{.Name}}</a></td> + <tr id="{{.Name}}"> + <td><a href="#{{.Name}}">{{.Number}}</a></td> + <td><a href="{{.DocURL}}" target="_blank" rel="noopener">{{.Name}}</a></td> <td>{{.Support}}</td> <td>{{.Note}} {{range $i, $url := .URLs}}<br/>See: <a href="{{.}}">{{.}}</a>{{end}}</td> </tr> @@ -92,6 +93,27 @@ func Fatalf(format string, a ...interface{}) { os.Exit(1) } +// syscallDocURL returns a doc url for a given syscall, doing its best to return a url that exists. +func syscallDocURL(name string) string { + customDocs := map[string]string{ + "io_pgetevents": "https://man7.org/linux/man-pages/man2/syscalls.2.html", + "rseq": "https://man7.org/linux/man-pages/man2/syscalls.2.html", + "io_uring_setup": "https://manpages.debian.org/buster-backports/liburing-dev/io_uring_setup.2.en.html", + "io_uring_enter": "https://manpages.debian.org/buster-backports/liburing-dev/io_uring_enter.2.en.html", + "io_uring_register": "https://manpages.debian.org/buster-backports/liburing-dev/io_uring_register.2.en.html", + "open_tree": "https://man7.org/linux/man-pages/man2/syscalls.2.html", + "move_mount": "https://man7.org/linux/man-pages/man2/syscalls.2.html", + "fsopen": "https://man7.org/linux/man-pages/man2/syscalls.2.html", + "fsconfig": "https://man7.org/linux/man-pages/man2/syscalls.2.html", + "fsmount": "https://man7.org/linux/man-pages/man2/syscalls.2.html", + "fspick": "https://man7.org/linux/man-pages/man2/syscalls.2.html", + } + if url, ok := customDocs[name]; ok { + return url + } + return fmt.Sprintf("http://man7.org/linux/man-pages/man2/%s.2.html", name) +} + func main() { inputFlag := flag.String("in", "-", "File to input ('-' for stdin)") outputDir := flag.String("out", ".", "Directory to output files.") @@ -145,6 +167,7 @@ func main() { Syscalls []struct { Name string Number uintptr + DocURL string Support string Note string URLs []string @@ -161,6 +184,7 @@ func main() { Syscalls: []struct { Name string Number uintptr + DocURL string Support string Note string URLs []string @@ -187,14 +211,16 @@ func main() { data.Syscalls = append(data.Syscalls, struct { Name string Number uintptr + DocURL string Support string Note string URLs []string }{ Name: s.Name, Number: num, + DocURL: syscallDocURL(s.Name), Support: s.Support, - Note: s.Note, // TODO urls + Note: s.Note, URLs: s.URLs, }) } |