diff options
100 files changed, 1776 insertions, 715 deletions
diff --git a/g3doc/user_guide/tutorials/BUILD b/g3doc/user_guide/tutorials/BUILD index f405349b3..a862c76f4 100644 --- a/g3doc/user_guide/tutorials/BUILD +++ b/g3doc/user_guide/tutorials/BUILD @@ -37,10 +37,19 @@ doc( ) doc( + name = "knative", + src = "knative.md", + category = "User Guide", + permalink = "/docs/tutorials/knative/", + subcategory = "Tutorials", + weight = "40", +) + +doc( name = "cni", src = "cni.md", category = "User Guide", permalink = "/docs/tutorials/cni/", subcategory = "Tutorials", - weight = "40", + weight = "50", ) diff --git a/g3doc/user_guide/tutorials/knative.md b/g3doc/user_guide/tutorials/knative.md new file mode 100644 index 000000000..3f5207fcc --- /dev/null +++ b/g3doc/user_guide/tutorials/knative.md @@ -0,0 +1,88 @@ +# Knative Services + +[Knative](https://knative.dev/) is a platform for running serverless workloads +on Kubernetes. This guide will show you how to run basic Knative workloads in +gVisor. + +## Prerequisites + +This guide assumes you have a cluster that is capable of running gVisor +workloads. This could be a +[GKE Sandbox](https://cloud.google.com/kubernetes-engine/sandbox/) enabled +cluster on Google Cloud Platform or one you have set up yourself using +[containerd Quick Start](https://gvisor.dev/docs/user_guide/containerd/quick_start/). + +This guide will also assume you have Knative installed using +[Istio](https://istio.io/) as the network layer. You can follow the +[Knative installation guide](https://knative.dev/docs/install/install-serving-with-yaml/) +to install Knative. + +## Enable the RuntimeClass feature flag + +Knative allows the use of various parameters on Pods via +[feature flags](https://knative.dev/docs/serving/feature-flags/). We will enable +the +[runtimeClassName](https://knative.dev/docs/serving/feature-flags/#kubernetes-runtime-class) +feature flag to enable the use of the Kubernetes +[Runtime Class](https://kubernetes.io/docs/concepts/containers/runtime-class/). 
+ +Edit the feature flags ConfigMap. + +```bash +kubectl edit configmap config-features -n knative-serving +``` + +Add `kubernetes.podspec-runtimeclassname: enabled` to the `data` field. Once +you are finished, the ConfigMap will look something like this (minus all the +system fields). + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: + name: config-features + namespace: knative-serving + labels: + serving.knative.dev/release: v0.22.0 +data: + kubernetes.podspec-runtimeclassname: enabled +``` + +## Deploy the Service + +After you have set the Runtime Class feature flag you can now create Knative +services that specify a `runtimeClassName` in the spec. + +```bash +cat <<EOF | kubectl apply -f - +apiVersion: serving.knative.dev/v1 +kind: Service +metadata: + name: helloworld-go +spec: + template: + spec: + runtimeClassName: gvisor + containers: + - image: gcr.io/knative-samples/helloworld-go + env: + - name: TARGET + value: "gVisor User" +EOF +``` + +You can see the pods running and their Runtime Class. + +```bash +kubectl get pods -o=custom-columns='NAME:.metadata.name,RUNTIME CLASS:.spec.runtimeClassName,STATUS:.status.phase' +``` + +Output should look something like the following. Note that your service might +scale to zero. If you access it via its URL you should get a new Pod. + +``` +NAME RUNTIME CLASS STATUS +helloworld-go-00002-deployment-646c87b7f5-5v68s gvisor Running +``` + +Congrats! Your Knative service is now running in gVisor! 
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a461bb65e..29ead20d0 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -15,6 +15,7 @@ go_library( "bpf.go", "capability.go", "clone.go", + "context.go", "dev.go", "elf.go", "epoll.go", @@ -77,6 +78,7 @@ go_library( deps = [ "//pkg/abi", "//pkg/bits", + "//pkg/context", "//pkg/marshal", "//pkg/marshal/primitive", ], diff --git a/pkg/abi/linux/context.go b/pkg/abi/linux/context.go new file mode 100644 index 000000000..d2dbba183 --- /dev/null +++ b/pkg/abi/linux/context.go @@ -0,0 +1,36 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/context" +) + +// contextID is the linux package's type for context.Context.Value keys. +type contextID int + +const ( + // CtxSignalNoInfoFunc is a Context.Value key for a function to send signals. + CtxSignalNoInfoFunc contextID = iota +) + +// SignalNoInfoFuncFromContext returns a callback function that can be used to send a +// signal to the given context. 
+func SignalNoInfoFuncFromContext(ctx context.Context) func(Signal) error { + if f := ctx.Value(CtxSignalNoInfoFunc); f != nil { + return f.(func(Signal) error) + } + return nil +} diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go index 74a55a123..514592b08 100644 --- a/pkg/crypto/crypto_stdlib.go +++ b/pkg/crypto/crypto_stdlib.go @@ -22,11 +22,11 @@ import ( // EcdsaVerify verifies the signature in r, s of hash using ECDSA and the // public key, pub. Its return value records whether the signature is valid. -func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { - return ecdsa.Verify(pub, hash, r, s) +func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) { + return ecdsa.Verify(pub, hash, r, s), nil } // SumSha384 returns the SHA384 checksum of the data. -func SumSha384(data []byte) (sum384 [sha512.Size384]byte) { - return sha512.Sum384(data) +func SumSha384(data []byte) ([sha512.Size384]byte, error) { + return sha512.Sum384(data), nil } diff --git a/pkg/linuxerr/BUILD b/pkg/linuxerr/BUILD new file mode 100644 index 000000000..c5abbd34f --- /dev/null +++ b/pkg/linuxerr/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "linuxerr", + srcs = ["linuxerr.go"], + visibility = ["//visibility:public"], + deps = ["//pkg/abi/linux"], +) + +go_test( + name = "linuxerr_test", + srcs = ["linuxerr_test.go"], + deps = [ + ":linuxerr", + "//pkg/syserror", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/linuxerr/linuxerr.go b/pkg/linuxerr/linuxerr.go new file mode 100644 index 000000000..f45caaadf --- /dev/null +++ b/pkg/linuxerr/linuxerr.go @@ -0,0 +1,184 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package linuxerr contains syscall error codes exported as error interface +// pointers. This allows for fast comparison and return operations comparable +// to unix.Errno constants. +package linuxerr + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// Error represents a syscall errno with a descriptive message. +type Error struct { + errno linux.Errno + message string +} + +func new(err linux.Errno, message string) *Error { + return &Error{ + errno: err, + message: message, + } +} + +// Error implements error.Error. +func (e *Error) Error() string { return e.message } + +// Errno returns the underlying linux.Errno value. +func (e *Error) Errno() linux.Errno { return e.errno } + +// The following variables have the same meaning as their errno equivalents. + +// Errno values from include/uapi/asm-generic/errno-base.h. 
+var ( + EPERM = new(linux.EPERM, "operation not permitted") + ENOENT = new(linux.ENOENT, "no such file or directory") + ESRCH = new(linux.ESRCH, "no such process") + EINTR = new(linux.EINTR, "interrupted system call") + EIO = new(linux.EIO, "I/O error") + ENXIO = new(linux.ENXIO, "no such device or address") + E2BIG = new(linux.E2BIG, "argument list too long") + ENOEXEC = new(linux.ENOEXEC, "exec format error") + EBADF = new(linux.EBADF, "bad file number") + ECHILD = new(linux.ECHILD, "no child processes") + EAGAIN = new(linux.EAGAIN, "try again") + ENOMEM = new(linux.ENOMEM, "out of memory") + EACCES = new(linux.EACCES, "permission denied") + EFAULT = new(linux.EFAULT, "bad address") + ENOTBLK = new(linux.ENOTBLK, "block device required") + EBUSY = new(linux.EBUSY, "device or resource busy") + EEXIST = new(linux.EEXIST, "file exists") + EXDEV = new(linux.EXDEV, "cross-device link") + ENODEV = new(linux.ENODEV, "no such device") + ENOTDIR = new(linux.ENOTDIR, "not a directory") + EISDIR = new(linux.EISDIR, "is a directory") + EINVAL = new(linux.EINVAL, "invalid argument") + ENFILE = new(linux.ENFILE, "file table overflow") + EMFILE = new(linux.EMFILE, "too many open files") + ENOTTY = new(linux.ENOTTY, "not a typewriter") + ETXTBSY = new(linux.ETXTBSY, "text file busy") + EFBIG = new(linux.EFBIG, "file too large") + ENOSPC = new(linux.ENOSPC, "no space left on device") + ESPIPE = new(linux.ESPIPE, "illegal seek") + EROFS = new(linux.EROFS, "read-only file system") + EMLINK = new(linux.EMLINK, "too many links") + EPIPE = new(linux.EPIPE, "broken pipe") + EDOM = new(linux.EDOM, "math argument out of domain of func") + ERANGE = new(linux.ERANGE, "math result not representable") +) + +// Errno values from include/uapi/asm-generic/errno.h. 
+var ( + EDEADLK = new(linux.EDEADLK, "resource deadlock would occur") + ENAMETOOLONG = new(linux.ENAMETOOLONG, "file name too long") + ENOLCK = new(linux.ENOLCK, "no record locks available") + ENOSYS = new(linux.ENOSYS, "invalid system call number") + ENOTEMPTY = new(linux.ENOTEMPTY, "directory not empty") + ELOOP = new(linux.ELOOP, "too many symbolic links encountered") + EWOULDBLOCK = new(linux.EWOULDBLOCK, "operation would block") + ENOMSG = new(linux.ENOMSG, "no message of desired type") + EIDRM = new(linux.EIDRM, "identifier removed") + ECHRNG = new(linux.ECHRNG, "channel number out of range") + EL2NSYNC = new(linux.EL2NSYNC, "level 2 not synchronized") + EL3HLT = new(linux.EL3HLT, "level 3 halted") + EL3RST = new(linux.EL3RST, "level 3 reset") + ELNRNG = new(linux.ELNRNG, "link number out of range") + EUNATCH = new(linux.EUNATCH, "protocol driver not attached") + ENOCSI = new(linux.ENOCSI, "no CSI structure available") + EL2HLT = new(linux.EL2HLT, "level 2 halted") + EBADE = new(linux.EBADE, "invalid exchange") + EBADR = new(linux.EBADR, "invalid request descriptor") + EXFULL = new(linux.EXFULL, "exchange full") + ENOANO = new(linux.ENOANO, "no anode") + EBADRQC = new(linux.EBADRQC, "invalid request code") + EBADSLT = new(linux.EBADSLT, "invalid slot") + EDEADLOCK = new(linux.EDEADLOCK, EDEADLK.message) + EBFONT = new(linux.EBFONT, "bad font file format") + ENOSTR = new(linux.ENOSTR, "device not a stream") + ENODATA = new(linux.ENODATA, "no data available") + ETIME = new(linux.ETIME, "timer expired") + ENOSR = new(linux.ENOSR, "out of streams resources") + ENONET = new(linux.ENONET, "machine is not on the network") + ENOPKG = new(linux.ENOPKG, "package not installed") + EREMOTE = new(linux.EREMOTE, "object is remote") + ENOLINK = new(linux.ENOLINK, "link has been severed") + EADV = new(linux.EADV, "advertise error") + ESRMNT = new(linux.ESRMNT, "srmount error") + ECOMM = new(linux.ECOMM, "communication error on send") + EPROTO = new(linux.EPROTO, "protocol 
error") + EMULTIHOP = new(linux.EMULTIHOP, "multihop attempted") + EDOTDOT = new(linux.EDOTDOT, "RFS specific error") + EBADMSG = new(linux.EBADMSG, "not a data message") + EOVERFLOW = new(linux.EOVERFLOW, "value too large for defined data type") + ENOTUNIQ = new(linux.ENOTUNIQ, "name not unique on network") + EBADFD = new(linux.EBADFD, "file descriptor in bad state") + EREMCHG = new(linux.EREMCHG, "remote address changed") + ELIBACC = new(linux.ELIBACC, "can not access a needed shared library") + ELIBBAD = new(linux.ELIBBAD, "accessing a corrupted shared library") + ELIBSCN = new(linux.ELIBSCN, ".lib section in a.out corrupted") + ELIBMAX = new(linux.ELIBMAX, "attempting to link in too many shared libraries") + ELIBEXEC = new(linux.ELIBEXEC, "cannot exec a shared library directly") + EILSEQ = new(linux.EILSEQ, "illegal byte sequence") + ERESTART = new(linux.ERESTART, "interrupted system call should be restarted") + ESTRPIPE = new(linux.ESTRPIPE, "streams pipe error") + EUSERS = new(linux.EUSERS, "too many users") + ENOTSOCK = new(linux.ENOTSOCK, "socket operation on non-socket") + EDESTADDRREQ = new(linux.EDESTADDRREQ, "destination address required") + EMSGSIZE = new(linux.EMSGSIZE, "message too long") + EPROTOTYPE = new(linux.EPROTOTYPE, "protocol wrong type for socket") + ENOPROTOOPT = new(linux.ENOPROTOOPT, "protocol not available") + EPROTONOSUPPORT = new(linux.EPROTONOSUPPORT, "protocol not supported") + ESOCKTNOSUPPORT = new(linux.ESOCKTNOSUPPORT, "socket type not supported") + EOPNOTSUPP = new(linux.EOPNOTSUPP, "operation not supported on transport endpoint") + EPFNOSUPPORT = new(linux.EPFNOSUPPORT, "protocol family not supported") + EAFNOSUPPORT = new(linux.EAFNOSUPPORT, "address family not supported by protocol") + EADDRINUSE = new(linux.EADDRINUSE, "address already in use") + EADDRNOTAVAIL = new(linux.EADDRNOTAVAIL, "cannot assign requested address") + ENETDOWN = new(linux.ENETDOWN, "network is down") + ENETUNREACH = new(linux.ENETUNREACH, "network is 
unreachable") + ENETRESET = new(linux.ENETRESET, "network dropped connection because of reset") + ECONNABORTED = new(linux.ECONNABORTED, "software caused connection abort") + ECONNRESET = new(linux.ECONNRESET, "connection reset by peer") + ENOBUFS = new(linux.ENOBUFS, "no buffer space available") + EISCONN = new(linux.EISCONN, "transport endpoint is already connected") + ENOTCONN = new(linux.ENOTCONN, "transport endpoint is not connected") + ESHUTDOWN = new(linux.ESHUTDOWN, "cannot send after transport endpoint shutdown") + ETOOMANYREFS = new(linux.ETOOMANYREFS, "too many references: cannot splice") + ETIMEDOUT = new(linux.ETIMEDOUT, "connection timed out") + ECONNREFUSED = new(linux.ECONNREFUSED, "connection refused") + EHOSTDOWN = new(linux.EHOSTDOWN, "host is down") + EHOSTUNREACH = new(linux.EHOSTUNREACH, "no route to host") + EALREADY = new(linux.EALREADY, "operation already in progress") + EINPROGRESS = new(linux.EINPROGRESS, "operation now in progress") + ESTALE = new(linux.ESTALE, "stale file handle") + EUCLEAN = new(linux.EUCLEAN, "structure needs cleaning") + ENOTNAM = new(linux.ENOTNAM, "not a XENIX named type file") + ENAVAIL = new(linux.ENAVAIL, "no XENIX semaphores available") + EISNAM = new(linux.EISNAM, "is a named type file") + EREMOTEIO = new(linux.EREMOTEIO, "remote I/O error") + EDQUOT = new(linux.EDQUOT, "quota exceeded") + ENOMEDIUM = new(linux.ENOMEDIUM, "no medium found") + EMEDIUMTYPE = new(linux.EMEDIUMTYPE, "wrong medium type") + ECANCELED = new(linux.ECANCELED, "operation Canceled") + ENOKEY = new(linux.ENOKEY, "required key not available") + EKEYEXPIRED = new(linux.EKEYEXPIRED, "key has expired") + EKEYREVOKED = new(linux.EKEYREVOKED, "key has been revoked") + EKEYREJECTED = new(linux.EKEYREJECTED, "key was rejected by service") + EOWNERDEAD = new(linux.EOWNERDEAD, "owner died") + ENOTRECOVERABLE = new(linux.ENOTRECOVERABLE, "state not recoverable") + ERFKILL = new(linux.ERFKILL, "operation not possible due to RF-kill") + EHWPOISON = 
new(linux.EHWPOISON, "memory page has hardware error") +) diff --git a/pkg/syserror/syserror_test.go b/pkg/linuxerr/linuxerr_test.go index c141e5f6e..d34937e93 100644 --- a/pkg/syserror/syserror_test.go +++ b/pkg/linuxerr/linuxerr_test.go @@ -19,6 +19,7 @@ import ( "testing" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -30,7 +31,13 @@ func BenchmarkAssignErrno(b *testing.B) { } } -func BenchmarkAssignError(b *testing.B) { +func BenchmarkLinuxerrAssignError(b *testing.B) { + for i := b.N; i > 0; i-- { + globalError = linuxerr.EINVAL + } +} + +func BenchmarkAssignSyserrorError(b *testing.B) { for i := b.N; i > 0; i-- { globalError = syserror.EINVAL } @@ -46,7 +53,17 @@ func BenchmarkCompareErrno(b *testing.B) { } } -func BenchmarkCompareError(b *testing.B) { +func BenchmarkCompareLinuxerrError(b *testing.B) { + globalError = linuxerr.E2BIG + j := 0 + for i := b.N; i > 0; i-- { + if globalError == linuxerr.EINVAL { + j++ + } + } +} + +func BenchmarkCompareSyserrorError(b *testing.B) { globalError = syserror.EAGAIN j := 0 for i := b.N; i > 0; i-- { @@ -62,7 +79,7 @@ func BenchmarkSwitchErrno(b *testing.B) { for i := b.N; i > 0; i-- { switch globalError { case unix.EINVAL: - j += 1 + j++ case unix.EINTR: j += 2 case unix.EAGAIN: @@ -71,13 +88,28 @@ func BenchmarkSwitchErrno(b *testing.B) { } } -func BenchmarkSwitchError(b *testing.B) { +func BenchmarkSwitchLinuxerrError(b *testing.B) { + globalError = linuxerr.EPERM + j := 0 + for i := b.N; i > 0; i-- { + switch globalError { + case linuxerr.EINVAL: + j++ + case linuxerr.EINTR: + j += 2 + case linuxerr.EAGAIN: + j += 3 + } + } +} + +func BenchmarkSwitchSyserrorError(b *testing.B) { globalError = syserror.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { case syserror.EINVAL: - j += 1 + j++ case syserror.EINTR: j += 2 case syserror.EAGAIN: diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index bcdb2dda2..819e140bc 100644 --- 
a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -92,7 +92,6 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF } if flags.Write { if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil { - fsmetric.GoferOpensWX.Increment() metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") log.Warningf("Opened a writable executable: %q", name) } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 0a954c138..eed05e369 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -60,7 +60,6 @@ func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD, return nil, err } if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { - fsmetric.GoferOpensWX.Increment() metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") } if atomic.LoadInt32(&d.mmapFD) >= 0 { diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index dc019ebd5..c12444b7e 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -101,7 +101,6 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*speci d.fs.specialFileFDs[fd] = struct{}{} d.fs.syncMu.Unlock() if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { - fsmetric.GoferOpensWX.Increment() metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") } if h.fd >= 0 { diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go index 7e535b527..17d0d5025 100644 --- a/pkg/sentry/fsmetric/fsmetric.go +++ b/pkg/sentry/fsmetric/fsmetric.go @@ -42,7 +42,6 @@ var ( // Metrics that only apply to fs/gofer and fsimpl/gofer. 
var ( - GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.") GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.") GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.") GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go index 0fbf27f64..c93ef6ac1 100644 --- a/pkg/sentry/kernel/cgroup.go +++ b/pkg/sentry/kernel/cgroup.go @@ -181,7 +181,23 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files for _, h := range r.hierarchies { if h.match(ctypes) { - h.fs.IncRef() + if !h.fs.TryIncRef() { + // Racing with filesystem destruction, namely h.fs.Release. + // Since we hold r.mu, we know the hierarchy hasn't been + // unregistered yet, but its associated filesystem is tearing + // down. + // + // If we simply indicate the hierarchy wasn't found without + // cleaning up the registry, the caller can race with the + // unregister and find itself temporarily unable to create a new + // hierarchy with a subset of the relevant controllers. + // + // To keep the result of FindHierarchy consistent with the + // uniqueness of controllers enforced by Register, drop the + // dying hierarchy now. The eventual unregister by the FS + // teardown will become a no-op. + return nil + } return h.fs } } @@ -230,12 +246,17 @@ func (r *CgroupRegistry) Register(cs []CgroupController, fs cgroupFS) error { return nil } -// Unregister removes a previously registered hierarchy from the registry. If -// the controller was not previously registered, Unregister is a no-op. 
+// Unregister removes a previously registered hierarchy from the registry. If no +// such hierarchy is registered, Unregister is a no-op. func (r *CgroupRegistry) Unregister(hid uint32) { r.mu.Lock() - defer r.mu.Unlock() + r.unregisterLocked(hid) + r.mu.Unlock() +} +// Precondition: Caller must hold r.mu. +// +checklocks:r.mu +func (r *CgroupRegistry) unregisterLocked(hid uint32) { if h, ok := r.hierarchies[hid]; ok { for name, _ := range h.controllers { delete(r.controllers, name) diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 2d89b9ccd..24e467e93 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -86,6 +86,12 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) if n > 0 { p.Notify(waiter.ReadableEvents) } + if err == unix.EPIPE { + // If we are returning EPIPE send SIGPIPE to the task. + if sendSig := linux.SignalNoInfoFuncFromContext(ctx); sendSig != nil { + sendSig(linux.SIGPIPE) + } + } return n, err } diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index fe2ab1662..3c5bd8ff7 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -702,7 +702,9 @@ func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool { return s.checkCapability(creds) } -// destroy destroys the set. Caller must hold 's.mu'. +// destroy destroys the set. +// +// Preconditions: Caller must hold 's.mu'. func (s *Set) destroy() { // Notify all waiters. They will fail on the next attempt to execute // operations and return error. 
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 70b0699dc..c82d9e82b 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -17,6 +17,7 @@ package kernel import ( "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -113,6 +114,10 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { return t.k.RealtimeClock() case limits.CtxLimits: return t.tg.limits + case linux.CtxSignalNoInfoFunc: + return func(sig linux.Signal) error { + return t.SendSignal(SignalInfoNoInfo(sig, t, t)) + } case pgalloc.CtxMemoryFile: return t.k.mf case pgalloc.CtxMemoryFileProvider: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 0b64a24c3..037ccfec8 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -77,41 +77,59 @@ func mustCreateGauge(name, description string) *tcpip.StatCounter { // Metrics contains metrics exported by netstack. 
var Metrics = tcpip.Stats{ - UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."), - MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), - DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), + DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped at the transport layer."), + NICs: tcpip.NICStats{ + UnknownL3ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l3_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L3 protocol."), + UnknownL4ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l4_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L4 protocol."), + MalformedL4RcvdPackets: mustCreateMetric("/netstack/nic/malformed_l4_received_packets", "Number of packets received that failed L4 header parsing."), + Tx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/tx/packets", "Number of packets transmitted."), + Bytes: mustCreateMetric("/netstack/nic/tx/bytes", "Number of bytes transmitted."), + }, + Rx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/rx/packets", "Number of packets received."), + Bytes: mustCreateMetric("/netstack/nic/rx/bytes", "Number of bytes received."), + }, + DisabledRx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/disabled_rx/packets", "Number of packets received on disabled NICs."), + Bytes: mustCreateMetric("/netstack/nic/disabled_rx/bytes", "Number of bytes received on disabled NICs."), + }, + Neighbor: tcpip.NICNeighborStats{ + UnreachableEntryLookups: mustCreateMetric("/netstack/nic/neighbor/unreachable_entry_loopups", "Number of lookups performed on 
a neighbor entry in Unreachable state."), + }, + }, ICMP: tcpip.ICMPStats{ V4: tcpip.ICMPv4Stats{ PacketsSent: tcpip.ICMPv4SentPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent."), + DstUnreachable: 
mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped by netstack due to link layer errors."), - RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped by netstack due to rate limit being exceeded."), + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received by netstack."), - EchoReply: 
mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench 
packets received."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received."), }, Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Number of ICMPv4 packets received that the transport layer could not parse."), }, @@ -119,40 +137,40 @@ var Metrics = tcpip.Stats{ V6: tcpip.ICMPv6Stats{ PacketsSent: tcpip.ICMPv6SentPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent by 
netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent by netstack."), - MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent by netstack."), - MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), - MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent."), + TimeExceeded: 
mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped by netstack due to link layer errors."), - RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped by netstack due to rate limit being exceeded."), + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped due to rate limit being exceeded."), }, 
PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received by netstack."), - MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received by netstack."), - MulticastListenerReport: 
mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), - MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 
multicast listener query packets received."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."), }, Unrecognized: mustCreateMetric("/netstack/icmp/v6/packets_received/unrecognized", "Number of ICMPv6 packets received that the transport layer does not know how to parse."), Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Number of ICMPv6 packets received that the transport layer could not parse."), @@ -163,23 +181,23 @@ var Metrics = tcpip.Stats{ IGMP: tcpip.IGMPStats{ PacketsSent: tcpip.IGMPSentPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent."), }, 
- Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped due to link layer errors."), }, PacketsReceived: tcpip.IGMPReceivedPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received."), }, - Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received by netstack that could not be parsed."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received that could not be parsed."), ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Number of received IGMP packets with bad checksums."), - Unrecognized: 
mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received by netstack."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received."), }, }, IP: tcpip.IPStats{ @@ -205,7 +223,8 @@ var Metrics = tcpip.Stats{ LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), ExtensionHeaderProblem: mustCreateMetric("/netstack/ip/forwarding/extension_header_problem", "Number of IP packets received which could not be forwarded due to a problem processing their IPv6 extension headers."), - PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not fit within the outgoing MTU."), + PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not be forwarded because they could not fit within the outgoing MTU."), + HostUnreachable: mustCreateMetric("/netstack/ip/forwarding/host_unreachable", "Number of IP packets received which could not be forwarded due to unresolvable next hop."), Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), }, }, @@ -1126,7 +1145,14 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, // TODO(b/64800844): Translate fields once they are added to // tcpip.TCPInfoOption. 
- info := linux.TCPInfo{} + info := linux.TCPInfo{ + State: uint8(v.State), + RTO: uint32(v.RTO / time.Microsecond), + RTT: uint32(v.RTT / time.Microsecond), + RTTVar: uint32(v.RTTVar / time.Microsecond), + SndSsthresh: v.SndSsthresh, + SndCwnd: v.SndCwnd, + } switch v.CcState { case tcpip.RTORecovery: info.CaState = linux.TCP_CA_Loss @@ -1137,11 +1163,6 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, case tcpip.Open: info.CaState = linux.TCP_CA_Open } - info.RTO = uint32(v.RTO / time.Microsecond) - info.RTT = uint32(v.RTT / time.Microsecond) - info.RTTVar = uint32(v.RTTVar / time.Microsecond) - info.SndSsthresh = v.SndSsthresh - info.SndCwnd = v.SndCwnd // In netstack reorderSeen is updated only when RACK is enabled. // We only track whether the reordering is seen, which is diff --git a/pkg/sync/README.md b/pkg/sync/README.md index 2183c4e20..be1a01f08 100644 --- a/pkg/sync/README.md +++ b/pkg/sync/README.md @@ -1,4 +1,4 @@ -# Syncutil +# sync This package provides additional synchronization primitives not provided by the Go stdlib 'sync' package. It is partially derived from the upstream 'sync' diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index 7d2f5adf6..76bee5a64 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,12 +8,3 @@ go_library( visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) - -go_test( - name = "syserror_test", - srcs = ["syserror_test.go"], - deps = [ - ":syserror", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index ea46c30da..ed4d7e958 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -39,12 +39,30 @@ go_library( deps_test( name = "netstack_deps_test", allowed = [ + # gVisor deps. 
+ "//pkg/atomicbitops", + "//pkg/buffer", + "//pkg/context", + "//pkg/gohacks", + "//pkg/goid", + "//pkg/ilist", + "//pkg/iovec", + "//pkg/linewriter", + "//pkg/log", + "//pkg/rand", + "//pkg/sleep", + "//pkg/state", + "//pkg/state/wire", + "//pkg/sync", + "//pkg/waiter", + + # Other deps. "@com_github_google_btree//:go_default_library", "@org_golang_x_sys//unix:go_default_library", "@org_golang_x_time//rate:go_default_library", ], allowed_prefixes = [ - "//", + "//pkg/tcpip", "@org_golang_x_sys//internal/unsafeheader", ], targets = [ diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 18e6cc3cd..bab640faf 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -701,7 +701,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp if !ok { return } - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) foundTS := false tsVal := uint32(0) @@ -748,12 +748,6 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does -// not contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - // TCPSACKBlockChecker creates a checker that verifies that the segment does // contain the specified SACK blocks in the TCP options. 
func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { @@ -765,7 +759,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } var gotSACKBlocks []header.SACKBlock - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) for i := 0; i < limit; { switch opts[i] { diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index fb819d7a8..c4dbe8440 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -218,6 +218,12 @@ func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { } } +// RunImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func (mc *ManualClock) RunImmediatelyScheduledJobs() { + mc.Advance(0) +} + // Advance executes all work that have been scheduled to execute within d from // the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 3d1bccd15..d6cad3a94 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -77,12 +77,12 @@ const ( // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag // field in the flags byte within an NDPPrefixInformation. - ndpPrefixInformationOnLinkFlagMask = (1 << 7) + ndpPrefixInformationOnLinkFlagMask = 1 << 7 // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the // Autonomous Address-Configuration flag field in the flags byte within // an NDPPrefixInformation. - ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6 // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 // field in the flags byte within an NDPPrefixInformation. @@ -451,7 +451,7 @@ func (o NDPNonceOption) String() string { // Nonce returns the nonce value this option holds. 
func (o NDPNonceOption) Nonce() []byte { - return []byte(o) + return o } // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index ef9126deb..f26c857eb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -288,5 +288,5 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 94209b026..f89b55561 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -396,7 +396,7 @@ func TestDirectRequest(t *testing.T) { nicID: nicID, entry: stack.NeighborEntry{ Addr: test.senderAddr, - LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), + LinkAddr: test.senderLinkAddr, State: stack.Stale, }, } diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 7daf64b4a..dadfc28cc 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -275,15 +275,23 @@ func TestMemoryLimits(t *testing.T) { highLimit := 3 * lowLimit // Allow at most 3 such packets. f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 1. 
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) + if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) + if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) + if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { + t.Fatal(err) + } if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -300,9 +308,13 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { memSize := pkt(1, "0").MemSize() f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send the same packet again. 
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index a22b712c6..24687cf06 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -133,7 +133,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) } // Wait for any initially fired timers to complete. - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -147,7 +147,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -156,7 +156,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -170,7 +170,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res 
!= stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -208,7 +208,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -247,7 +247,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -272,7 +272,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -347,7 +347,7 @@ func TestNonce(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() for i, want := range test.expectedResults { if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) 
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 0c2b62127..40ab21cb6 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -42,6 +42,10 @@ type MultiCounterIPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig tcpip.MultiCounterStat + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable tcpip.MultiCounterStat + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -61,6 +65,7 @@ func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig) m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) + m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable) } // LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index cec3e62c4..a180e5c75 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip/network/internal/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", + "//pkg/tcpip/tests/integration:__pkg__", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index d1a82b584..5f6b0c6af 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -222,7 +222,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() 
e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -481,6 +480,22 @@ func (*icmpReasonFragmentationNeeded) isForwarding() bool { return true } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 792 page 5, Destination Unreachable Message, + // + // In addition, in some networks, the gateway may be able to determine + // if the internet destination host is unreachable. Gateways in these + // networks may send destination unreachable messages to the source host + // when the destination host is unreachable. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -537,7 +552,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. 
+ netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -653,6 +673,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4NetUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4HostUnreachable) + counter = sent.dstUnreachable case *icmpReasonFragmentationNeeded: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 23178277a..bb8d53c12 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -104,6 +104,16 @@ type endpoint struct { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, return an ICMP error to the original + // packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. 
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -898,7 +908,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) case *ip.ErrNoRoute: stats.ip.Forwarding.Unrouteable.Increment() case *ip.ErrParameterProblem: - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() case *ip.ErrMessageTooLong: stats.ip.Forwarding.PacketTooBig.Increment() @@ -935,7 +944,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -1008,7 +1016,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() } return diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index da9cc0ae8..7a8e0aa24 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -36,10 +36,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" - tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -112,10 +112,6 @@ func TestExcludeBroadcast(t *testing.T) { }) } -type forwardedPacket struct { - fragments []fragmentInfo -} - func TestForwarding(t 
*testing.T) { const ( incomingNICID = 1 @@ -134,11 +130,11 @@ func TestForwarding(t *testing.T) { PrefixLen: 8, } outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2") - remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2") - unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2") - multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0") - linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0") + remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2") + remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2") + unreachableIPv4Addr := testutil.MustParse4("12.0.0.2") + multicastIPv4Addr := testutil.MustParse4("225.0.0.0") + linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0") tests := []struct { name string @@ -453,7 +449,7 @@ func TestForwarding(t *testing.T) { return len(hdr.View()) } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(incomingIPv4Addr.Address), checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), @@ -461,7 +457,7 @@ func TestForwarding(t *testing.T) { checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), + checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) } else if ok { @@ -470,7 +466,7 @@ func TestForwarding(t *testing.T) { if test.expectPacketForwarded { if len(test.expectedFragmentsForwarded) != 0 { - fragmentedPackets := []*stack.PacketBuffer{} + var fragmentedPackets []*stack.PacketBuffer for i := 0; i < len(test.expectedFragmentsForwarded); i++ { reply, ok = outgoingEndpoint.Read() if !ok { @@ -487,7 +483,7 @@ func TestForwarding(t *testing.T) { // maximum IP header size and the maximum size allocated for link layer // headers. In this case, no size is allocated for link layer headers. 
expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize - if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { + if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { t.Error(err) } } else { @@ -496,7 +492,7 @@ func TestForwarding(t *testing.T) { t.Fatal("expected ICMP Echo packet through outgoing NIC") } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), @@ -1315,7 +1311,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1334,7 +1330,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4( checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1546,9 +1542,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := 
r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -1602,7 +1598,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) for _, test := range writePacketsTests { t.Run(test.description, func(t *testing.T) { @@ -1612,13 +1608,13 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -1726,8 +1722,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: 
tcp.ProtocolNumber, @@ -2301,7 +2297,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv4Type(header.ICMPv4TimeExceeded), checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), checker.ICMPv4Checksum(), - checker.ICMPv4Payload([]byte(firstFragmentSent)), + checker.ICMPv4Payload(firstFragmentSent), ), ) }) @@ -2705,7 +2701,7 @@ func TestReceiveFragments(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, RawFactory: raw.EndpointFactory{}, }) - e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) + e := channel.New(0, 1280, "\xf0\x00") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -2948,7 +2944,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -3074,7 +3070,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ @@ -3244,8 +3240,8 @@ func TestCloseLocking(t *testing.T) { ) var ( - src = tcptestutil.MustParse4("16.0.0.1") - dst = tcptestutil.MustParse4("16.0.0.2") + src = testutil.MustParse4("16.0.0.1") + dst = testutil.MustParse4("16.0.0.2") ) s := stack.New(stack.Options{ @@ -3253,7 +3249,7 @@ func TestCloseLocking(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 
}) - // Perform NAT so that the endoint tries to search for a sibling endpoint + // Perform NAT so that the endpoint tries to search for a sibling endpoint // which ends up taking the protocol and endpoint lock (in that order). table := stack.Table{ Rules: []stack.Rule{ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 307e1972d..23fc94303 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -1029,6 +1029,26 @@ func (*icmpReasonNetUnreachable) respondsToMulticast() bool { return false } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 4443 page 8, Destination Unreachable Message, + // + // If the reason for the failure to deliver cannot be mapped to any of + // other codes, the Code field is set to 3. Example of such cases are + // an inability to resolve the IPv6 destination address into a + // corresponding link address, or a link-specific problem of some sort. + return true +} + +func (*icmpReasonHostUnreachable) respondsToMulticast() bool { + return false +} + // icmpReasonFragmentationNeeded is an error where a packet is to big to be sent // out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message. type icmpReasonPacketTooBig struct{} @@ -1143,7 +1163,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. 
The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -1222,6 +1247,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6AddressUnreachable) + counter = sent.dstUnreachable case *icmpReasonPacketTooBig: icmpHdr.SetType(header.ICMPv6PacketTooBig) icmpHdr.SetCode(header.ICMPv6UnusedCode) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 040cd4bc8..d7b04554e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -21,7 +21,6 @@ import ( "reflect" "strings" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -46,16 +45,12 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 - // Extra time to use when waiting for an async event to occur. 
- defaultAsyncPositiveEventTimeout = 30 * time.Second - arbitraryHopLimit = 42 ) var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) - lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -1309,7 +1304,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 95e11ac51..68f8308f2 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -282,6 +282,16 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, we should return an ICMP error to the + // original packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. 
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index f0ff111c5..11ff36561 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1754,7 +1754,7 @@ func (ndp *ndpState) startSolicitingRouters() { header.NDPSourceLinkLayerAddressOption(linkAddress), } } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + optsSerializer.Length() icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) rs := header.NDPRouterSolicit(icmpData.MessageBody()) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 234e34952..b300ed894 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -32,58 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep := 
netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - t.Cleanup(ep.Close) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := llladdr.WithPrefix() - if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - addressEP.DecRef() - } - - return s, ep -} - var _ NDPDispatcher = (*testNDPDispatcher)(nil) // testNDPDispatcher is an NDPDispatcher only allows default router discovery. @@ -163,11 +111,6 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } -type linkResolutionResult struct { - linkAddr tcpip.LinkAddress - ok bool -} - // TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index b26936b7f..b7c2de652 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -222,7 +222,7 @@ type SocketOptions struct { getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` // receiveBufferSize determines the receive buffer size for this socket. - receiveBufferSize int64 + receiveBufferSize atomicbitops.AlignedAtomicInt64 // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -653,13 +653,13 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { // GetReceiveBufferSize gets value for SO_RCVBUF option. 
func (so *SocketOptions) GetReceiveBufferSize() int64 { - return atomic.LoadInt64(&so.receiveBufferSize) + return so.receiveBufferSize.Load() } // SetReceiveBufferSize sets value for SO_RCVBUF option. func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { if !notify { - atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize) + so.receiveBufferSize.Store(receiveBufferSize) return } @@ -684,8 +684,8 @@ func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bo v = math.MaxInt32 } - oldSz := atomic.LoadInt64(&so.receiveBufferSize) + oldSz := so.receiveBufferSize.Load() // Notify endpoint about change in buffer size. newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) - atomic.StoreInt64(&so.receiveBufferSize, newSz) + so.receiveBufferSize.Store(newSz) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 84aa6a9e4..395ff9a07 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7107d598d..d971db010 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -114,10 +114,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -134,7 +130,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. 
-func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -319,7 +315,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -354,7 +350,7 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 9821a18d3..90881169d 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. 
testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. 
-func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -1068,7 +1060,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // periodically refreshes the frequently used entry. 
// Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) { linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..463d017fc 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -307,7 +307,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. 
e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -361,7 +361,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +378,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..c2a291244 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. // It either resolves the name immediately or returns the empty LinkAddress. 
-func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dbba2c79f..378389db2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. 
+func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + 
n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,7 +773,7 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } @@ -807,20 +794,20 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +839,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. 
switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ 
package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { 
t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..02f905351 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -323,26 +323,21 @@ type Rand interface { type NUDState struct { rng Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex - - config NUDConfigurations - - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration - - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + mu struct { + sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. 
+ reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified @@ -351,7 +346,7 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { s := &NUDState{ rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +354,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +372,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if time.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +403,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. 
func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. 
- s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = time.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..6ba97d626 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -28,17 +28,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 01652fbe7..4ca702121 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -245,10 +245,10 @@ func (pk *PacketBuffer) dataOffset() int { func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] if h.length > 0 { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) 
} if pk.pushed+size > pk.reserved { - panic("not enough headroom reserved") + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } pk.pushed += size h.offset = -pk.pushed diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8814f45a6..72760a4a7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -40,13 +40,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. - resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -804,7 +797,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. 
NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -856,7 +849,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 02d54d29b..73e0f0d58 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -166,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -197,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. 
-func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -463,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -920,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. 
if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -941,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1066,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1118,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1140,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. 
for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1220,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1248,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1287,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1324,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. 
// testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1574,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1615,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1641,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. 
@@ -1660,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1709,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2049,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2113,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2234,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := 
stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2248,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. 
- buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + nicStats := s.NICInfo()[nicid].Stats + + // Inbound packet. + rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. 
+ txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2316,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, 
"\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -3837,8 +3874,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3875,8 +3910,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4223,7 +4256,7 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 797778e08..34f820053 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -861,6 +861,9 @@ type SettableSocketOption interface { isSettableSocketOption() } +// EndpointState represents the state of an endpoint. +type EndpointState uint8 + // CongestionControlState indicates the current congestion control state for // TCP sender. type CongestionControlState int @@ -897,6 +900,9 @@ type TCPInfoOption struct { // RTO is the retransmission timeout for the endpoint. RTO time.Duration + // State is the current endpoint protocol state. + State EndpointState + // CcState is the congestion control state. CcState CongestionControlState @@ -1552,6 +1558,10 @@ type IPForwardingStats struct { // were too big for the outgoing MTU. 
PacketTooBig *StatCounter + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable *StatCounter + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -1835,37 +1845,104 @@ type UDPStats struct { ChecksumErrors *StatCounter } +// NICNeighborStats holds metrics for the neighbor table. +type NICNeighborStats struct { + // LINT.IfChange(NICNeighborStats) + + // UnreachableEntryLookups counts the number of lookups performed on an + // entry in Unreachable state. + UnreachableEntryLookups *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICNeighborStats) +} + +// NICPacketStats holds basic packet statistics. +type NICPacketStats struct { + // LINT.IfChange(NICPacketStats) + + // Packets is the number of packets counted. + Packets *StatCounter + + // Bytes is the number of bytes counted. + Bytes *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICPacketStats) +} + +// NICStats holds NIC statistics. +type NICStats struct { + // LINT.IfChange(NICStats) + + // UnknownL3ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported network protocol. + UnknownL3ProtocolRcvdPackets *StatCounter + + // UnknownL4ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported transport protocol. + UnknownL4ProtocolRcvdPackets *StatCounter + + // MalformedL4RcvdPackets is the number of packets received by a NIC that + // could not be delivered to a transport endpoint because the L4 header could + // not be parsed. + MalformedL4RcvdPackets *StatCounter + + // Tx contains statistics about transmitted packets. + Tx NICPacketStats + + // Rx contains statistics about received packets. + Rx NICPacketStats + + // DisabledRx contains statistics about received packets on disabled NICs. 
+ DisabledRx NICPacketStats + + // Neighbor contains statistics about neighbor entries. + Neighbor NICNeighborStats + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICStats) +} + +// FillIn returns a copy of s with nil fields initialized to new StatCounters. +func (s NICStats) FillIn() NICStats { + InitStatCounters(reflect.ValueOf(&s).Elem()) + return s +} + // Stats holds statistics about the networking stack. -// -// All fields are optional. type Stats struct { - // UnknownProtocolRcvdPackets is the number of packets received by the - // stack that were for an unknown or unsupported protocol. - UnknownProtocolRcvdPackets *StatCounter - - // MalformedRcvdPackets is the number of packets received by the stack - // that were deemed malformed. - MalformedRcvdPackets *StatCounter + // TODO(https://gvisor.dev/issues/5986): Make the DroppedPackets stat less + // ambiguous. - // DroppedPackets is the number of packets dropped due to full queues. + // DroppedPackets is the number of packets dropped at the transport layer. DroppedPackets *StatCounter - // ICMP breaks out ICMP-specific stats (both v4 and v6). + // NICs is an aggregation of every NIC's statistics. These should not be + // incremented using this field, but using the relevant NIC multicounters. + NICs NICStats + + // ICMP is an aggregation of every NetworkEndpoint's ICMP statistics (both v4 + // and v6). These should not be incremented using this field, but using the + // relevant NetworkEndpoint ICMP multicounters. ICMP ICMPStats - // IGMP breaks out IGMP-specific stats. + // IGMP is an aggregation of every NetworkEndpoint's IGMP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint IGMP multicounters. IGMP IGMPStats - // IP breaks out IP-specific stats (both v4 and v6). + // IP is an aggregation of every NetworkEndpoint's IP statistics. 
These should + // not be incremented using this field, but using the relevant NetworkEndpoint + // IP multicounters. IP IPStats - // ARP breaks out ARP-specific stats. + // ARP is an aggregation of every NetworkEndpoint's ARP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint ARP multicounters. ARP ARPStats - // TCP breaks out TCP-specific stats. + // TCP holds TCP-specific stats. TCP TCPStats - // UDP breaks out UDP-specific stats. + // UDP holds UDP-specific stats. UDP UDPStats } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 269081ff8..c96ae2f02 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "net" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -210,26 +209,6 @@ func TestAddressString(t *testing.T) { } } -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. 
- "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - func TestAddressWithPrefixSubnet(t *testing.T) { tests := []struct { addr Address diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index ab2dab60c..8802f36b2 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -48,17 +48,20 @@ go_test( size = "small", srcs = ["link_resolution_test.go"], deps = [ + "//pkg/context", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index c657714ba..9f727eb8f 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -17,22 +17,26 @@ package link_resolution_test import ( "bytes" "fmt" + "net" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + 
tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -395,6 +399,246 @@ func TestTCPLinkResolutionFailure(t *testing.T) { } } +func TestForwardingWithLinkResolutionFailure(t *testing.T) { + const ( + incomingNICID = 1 + outgoingNICID = 2 + ttl = 2 + expectedHostUnreachableErrorCount = 1 + ) + outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != arp.ProtocolNumber { + t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) + } + if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(request.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) + } + } + + ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", 
request.Proto, header.IPv6ProtocolNumber) + } + + snmc := header.SolicitedNodeAddr(dst) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), + checker.SrcAddr(src), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(dst), + )) + } + + icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4HostUnreachable), + ), + ) + } + + icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv6.DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6AddressUnreachable), + ), + ) + } + + tests := []struct { + name string + networkProtocolFactory []stack.NetworkProtocolFactory + networkProtocolNumber tcpip.NetworkProtocolNumber + sourceAddr tcpip.Address + destAddr tcpip.Address + incomingAddr tcpip.AddressWithPrefix + outgoingAddr tcpip.AddressWithPrefix + transportProtocol func(*stack.Stack) stack.TransportProtocol + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) + icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) + mtu uint32 + }{ + { + name: "IPv4 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + networkProtocolNumber: header.IPv4ProtocolNumber, + sourceAddr: tcptestutil.MustParse4("10.0.0.2"), + destAddr: 
tcptestutil.MustParse4("11.0.0.2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + }, + transportProtocol: icmp.NewProtocol4, + linkResolutionRequestChecker: arpChecker, + icmpReplyChecker: icmpv4Checker, + rx: rxICMPv4EchoRequest, + mtu: ipv4.MaxTotalSize, + }, + { + name: "IPv6 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + networkProtocolNumber: header.IPv6ProtocolNumber, + sourceAddr: tcptestutil.MustParse6("10::2"), + destAddr: tcptestutil.MustParse6("11::2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + }, + transportProtocol: icmp.NewProtocol6, + linkResolutionRequestChecker: ndpChecker, + icmpReplyChecker: icmpv6Checker, + rx: rxICMPv6EchoRequest, + mtu: header.IPv6MinimumMTU, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + + s := stack.New(stack.Options{ + NetworkProtocols: test.networkProtocolFactory, + TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, + Clock: clock, + }) + + // Set up endpoint through which we will receive packets. 
+ incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) + } + incomingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.incomingAddr, + } + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + } + + // Set up endpoint through which we will attempt to forward packets. + outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) + outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) + } + outgoingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.outgoingAddr, + } + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.incomingAddr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: test.outgoingAddr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) + } + + test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) + + var request channel.PacketInfo + var ok bool + nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) + if err != nil { + t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) + } + // Trigger the first packet on the endpoint. 
+ clock.RunImmediatelyScheduledJobs() + + for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { + if request, ok = outgoingEndpoint.Read(); !ok { + t.Fatal("expected ARP packet through outgoing NIC") + } + + test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) + + // Advance the clock the span of one request timeout. + clock.Advance(nudConfigs.RetransmitTimer) + } + + // Next, we make a blocking read to retrieve the error packet. This is + // necessary because outgoing packets are dequeued asynchronously when + // link resolution fails, and this dequeue is what triggers the ICMP + // error. + // + // TODO(gvisor.dev/issue/6012): Replace with asynchronous read after we + // have integrated the stack clock with the dequeuing code. + reply, ok := incomingEndpoint.ReadContext(context.Background()) + if !ok { + t.Fatal("expected ICMP packet through incoming NIC") + } + + test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) + + // Since link resolution failed, we don't expect the packet to be + // forwarded. + forwardedPacket, ok := outgoingEndpoint.Read() + if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) + } + + if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { + t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) + } + }) + } +} + func TestGetLinkAddress(t *testing.T) { const ( host1NICID = 1 diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go index f84d399fb..94b580a70 100644 --- a/pkg/tcpip/testutil/testutil.go +++ b/pkg/tcpip/testutil/testutil.go @@ -109,3 +109,15 @@ func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) er return nil } + +// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on +// error. 
+// +// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. +func MustParseLink(addr string) tcpip.LinkAddress { + parsed, err := tcpip.ParseMACAddress(addr) + if err != nil { + panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) + } + return parsed +} diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/tcpip/time.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index 1633d0aeb..4ddb7020d 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -25,7 +25,6 @@ import ( const ( shortDuration = 1 * time.Nanosecond middleDuration = 100 * time.Millisecond - longDuration = 1 * time.Second ) func TestJobReschedule(t *testing.T) { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8afde7fca..517903ae7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -160,7 +160,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. 
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { @@ -349,7 +349,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { +func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { return nil } @@ -390,7 +390,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -606,7 +606,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 496eca581..fa703a0ed 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -159,7 +159,7 @@ func (ep *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -220,19 +220,19 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. 
Packet sockets cannot be // connected, and this function always returnes *tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { +func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error { return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -318,7 +318,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -339,7 +339,7 @@ func (ep *endpoint) UpdateLastError(err tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -484,7 +484,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { } // SetOwner implements tcpip.Endpoint.SetOwner. -func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*endpoint) SetOwner(tcpip.PacketOwner) {} // SocketOptions implements tcpip.Endpoint.SocketOptions. 
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index bcec3d2e7..07a585444 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -183,7 +183,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.mu.Lock() @@ -402,7 +402,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Find a route to the destination. - route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) + route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) if err != nil { return err } @@ -428,7 +428,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -439,7 +439,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -513,12 +513,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5e03e7715..05b41e0f8 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1235,7 +1235,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - state := e.state + state := e.EndpointState() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 512053a04..0ca986512 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -177,7 +177,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans s := newIncomingSegment(id, pkt) if !s.parse(pkt.RXTransportChecksumValidated) { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() @@ -185,7 +184,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans } if !s.csumValid { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.ChecksumErrors.Increment() ep.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f148d505d..5342aacfd 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -421,7 +421,7 @@ func testV4Accept(t *testing.T, c 
*context.Context) { r.Reset(data) nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) + tcp = header.IPv4(b).Payload() if string(tcp.Payload()) != data { t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 50d39cbad..fb7670adb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -38,19 +38,15 @@ import ( ) // EndpointState represents the state of a TCP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - // Endpoint states internal to netstack. These map to the TCP state CLOSED. - StateInitial EndpointState = iota - StateBound - StateConnecting // Connect() called, but the initial SYN hasn't been sent. - StateError - - // TCP protocol states. + _ EndpointState = iota + // TCP protocol states in sync with the definitions in + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/tcp_states.h#L13 StateEstablished StateSynSent StateSynRecv @@ -62,6 +58,12 @@ const ( StateLastAck StateListen StateClosing + + // Endpoint states internal to netstack. + StateInitial + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError ) const ( @@ -97,6 +99,16 @@ func (s EndpointState) connecting() bool { } } +// internal returns true when the state is netstack internal. +func (s EndpointState) internal() bool { + switch s { + case StateInitial, StateBound, StateConnecting, StateError: + return true + default: + return false + } +} + // handshake returns true when s is one of the states representing an endpoint // in the middle of a TCP handshake. 
func (s EndpointState) handshake() bool { @@ -422,12 +434,12 @@ type endpoint struct { // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState `state:".(EndpointState)"` + state uint32 `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the // endpoint state at restore time as the socket is moved to it's correct // state. - origEndpointState EndpointState `state:"nosave"` + origEndpointState uint32 `state:"nosave"` isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` @@ -747,7 +759,7 @@ func (e *endpoint) ResumeWork() { // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + oldstate := EndpointState(atomic.LoadUint32(&e.state)) switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() @@ -764,12 +776,12 @@ func (e *endpoint) setEndpointState(state EndpointState) { e.stack.Stats().TCP.CurrentEstablished.Decrement() } } - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // setRecentTimestamp sets the recentTS field to the provided value. @@ -806,11 +818,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue }, sndQueueInfo: sndQueueInfo{ TCPSndBufState: stack.TCPSndBufState{ - SndMTU: int(math.MaxInt32), + SndMTU: math.MaxInt32, }, }, waiterQueue: waiterQueue, - state: StateInitial, + state: uint32(StateInitial), keepalive: keepalive{ // Linux defaults. 
idle: 2 * time.Hour, @@ -1956,6 +1968,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { info := tcpip.TCPInfoOption{} e.LockUser() + if state := e.EndpointState(); state.internal() { + info.State = tcpip.EndpointState(StateClose) + } else { + info.State = tcpip.EndpointState(state) + } snd := e.snd if snd != nil { // We do not calculate RTT before sending the data packets. If @@ -2731,7 +2748,7 @@ func (e *endpoint) updateSndBufferUsage(v int) { // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1 e.sndQueueInfo.sndQueueMu.Unlock() if notify { diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 6e9777fe4..a56d34dc5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -154,7 +154,7 @@ func (e *endpoint) afterLoad() { e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. - e.state = StateInitial + e.state = uint32(StateInitial) // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. 
e.acceptCond = sync.NewCond(&e.acceptMu) @@ -167,7 +167,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() - epState := e.origEndpointState + epState := EndpointState(e.origEndpointState) switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption @@ -281,11 +281,11 @@ func (e *endpoint) Resume(s *stack.Stack) { }() case epState == StateClose: e.isPortReserved = false - e.state = StateClose + e.state = uint32(StateClose) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) case epState == StateError: - e.state = StateError + e.state = uint32(StateError) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index a3d1aa1a3..d43e21426 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -401,7 +401,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er case *tcpip.TCPTimeWaitReuseOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) + *v = p.timeWaitReuse p.mu.RUnlock() return nil diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 9e332dcf7..813a6dffd 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -279,7 +279,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // been observed RACK uses reo_wnd of zero during loss recovery, in order to // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. 
-func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { +func (rc *rackControl) updateRACKReorderWindow() { dsackSeen := rc.DSACKSeen snd := rc.snd diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 133371455..4fd8a0624 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -137,9 +137,9 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + if r.RcvNxt.Add(curWnd).LessThan(r.RcvNxt.Add(newWnd)) && toGrow { // If the new window moves the right edge, then update RcvAcc. - r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) + r.RcvAcc = r.RcvNxt.Add(newWnd) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7e5ba6ef7..61754de29 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -243,7 +243,7 @@ func (s *segment) parse(skipChecksumValidation bool) bool { return false } - s.options = []byte(s.hdr[header.TCPMinimumSize:]) + s.options = s.hdr[header.TCPMinimumSize:] s.parsedOptions = header.ParseTCPOptions(s.options) if skipChecksumValidation { s.csumValid = true @@ -262,5 +262,5 @@ func (s *segment) parse(skipChecksumValidation bool) bool { // sackBlock returns a header.SACKBlock that represents this segment. 
func (s *segment) sackBlock() header.SACKBlock { - return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} + return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())} } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index f43e86677..6b7293755 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -616,7 +616,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // 'S2' that meets the following 3 criteria for determinig // loss, the sequence range of one segment of up to SMSS // octects starting with S2 MUST be returned. - if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) { + if !s.ep.scoreboard.IsSACKED(header.SACKBlock{Start: segSeq, End: segSeq.Add(1)}) { // NextSeg(): // // (1.a) S2 is greater than HighRxt @@ -1024,7 +1024,7 @@ func (s *sender) SetPipe() { if segEnd.LessThan(endSeq) { endSeq = segEnd } - sb := header.SACKBlock{startSeq, endSeq} + sb := header.SACKBlock{Start: startSeq, End: endSeq} // SetPipe(): // // After initializing pipe to zero, the following steps are @@ -1455,7 +1455,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: // * Step 4: Update RACK reordering window - s.rc.updateRACKReorderWindow(rcvdSeg) + s.rc.updateRACKReorderWindow() // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index c58361bc1..29af87b7d 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -38,7 +38,7 @@ const ( func setStackRACKPermitted(t *testing.T, c *context.Context) { t.Helper() - opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection) + opt := 
tcpip.TCPRACKLossDetection if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9916182e3..e7ede7662 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4259,7 +4259,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(iss), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -5335,7 +5335,7 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next-1)), + checker.TCPSeqNum(next-1), checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), @@ -5360,12 +5360,7 @@ func TestKeepalive(t *testing.T) { }) checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), + checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), ) if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { @@ -5507,7 +5502,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. 
c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + uint16(lastPortOffset), + SrcPort: context.TestPort + lastPortOffset, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(context.TestInitialSequenceNumber), @@ -5884,7 +5879,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { r.Reset(data) newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) + tcp = header.IPv4(pkt).Payload() if string(tcp.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } @@ -6118,7 +6113,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { } pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + tcpHdr = header.IPv4(pkt).Payload() if string(tcpHdr.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) } @@ -6375,7 +6370,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Allocate a large enough payload for the test. 
payloadSize := receiveBufferSize * 2 - b := make([]byte, int(payloadSize)) + b := make([]byte, payloadSize) worker := (c.EP).(interface { StopWork() @@ -6429,7 +6424,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // ack, 1 for the non-zero window p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(uint32(wantAckNum)), + checker.TCPAckNum(wantAckNum), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { @@ -6484,14 +6479,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := uint32(rawEP.TSVal) + tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) + return uint16(rcvWnd >> c.WindowScale) } // Allocate a large array to send to the endpoint. b := make([]byte, receiveBufferSize*48) @@ -6619,19 +6614,16 @@ func TestDelayEnabled(t *testing.T) { defer c.Cleanup() checkDelayOption(t, c, false, false) // Delay is disabled by default. 
- for _, v := range []struct { - delayEnabled tcpip.TCPDelayEnabled - wantDelayOption bool - }{ - {delayEnabled: false, wantDelayOption: false}, - {delayEnabled: true, wantDelayOption: true}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + for _, delayEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + opt := tcpip.TCPDelayEnabled(delayEnabled) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) + } + checkDelayOption(t, c, opt, delayEnabled) + }) } } @@ -7042,7 +7034,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Receive the SYN-ACK reply. 
b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) + tcpHdr = header.IPv4(b).Payload() c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders = &context.Headers{ @@ -7467,7 +7459,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), + checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), @@ -7545,7 +7537,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: iss, - AckNum: seqnum.Value(c.IRS + 1), + AckNum: c.IRS + 1, RcvWnd: 30000, }) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f7dd50d35..623e069a6 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -40,13 +40,14 @@ type udpPacket struct { } // EndpointState represents the state of a UDP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - StateInitial EndpointState = iota + _ EndpointState = iota + StateInitial StateBound StateConnected StateClosed @@ -98,7 +99,7 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState + state uint32 route *stack.Route `state:"manual"` dstPort uint16 ttl uint8 @@ -176,7 +177,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. 
multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), - state: StateInitial, + state: uint32(StateInitial), uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) @@ -204,12 +205,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState() returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -290,7 +291,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. 
func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -801,8 +802,8 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ - e.multicastNICID, - e.multicastAddr, + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, } e.mu.Unlock() @@ -1301,12 +1302,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.LocalAddress, - Port: header.UDP(hdr).DestinationPort(), + Port: hdr.DestinationPort(), }, data: pkt.Data().ExtractVV(), } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 705ad1f64..7c357cb09 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = StateConnected + ep.state = uint32(StateConnected) ep.rcvMu.Lock() ep.rcvReady = true diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index dc2e3f493..2e283e52b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1462,7 +1462,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { name: "IPv4 unicast", proto: header.IPv4ProtocolNumber, flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, }, { name: "IPv4 multicast", @@ -1474,7 +1474,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. 
We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, }, { name: "IPv4 broadcast", @@ -1486,13 +1486,13 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, }, { name: "IPv6 unicast", proto: header.IPv6ProtocolNumber, flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, }, { name: "IPv6 multicast", @@ -1504,7 +1504,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). 
- expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, }, } @@ -2115,8 +2115,8 @@ func TestShortHeader(t *testing.T) { Data: buf.ToVectorisedView(), })) - if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) } } @@ -2124,25 +2124,27 @@ func TestShortHeader(t *testing.T) { // global and endpoint stats are incremented. func TestBadChecksumErrors(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV6} { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } + c.createEndpoint(flow.sockProto()) + // Bind to wildcard. 
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) + payload := newPayload() + c.injectPacket(flow, payload, true /* badChecksum */) - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } + }) } } diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 8d31e33b2..29e202b7d 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -131,8 +131,9 @@ func New(conf *config.Config, args *Args) (*Sandbox, error) { // The Cleanup object cleans up partially created sandboxes when an error // occurs. Any errors occurring during cleanup itself are ignored. 
c := cleanup.Make(func() { - err := s.destroy() - log.Warningf("error destroying sandbox: %v", err) + if err := s.destroy(); err != nil { + log.Warningf("error destroying sandbox: %v", err) + } }) defer c.Clean() diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 616215dc3..d8059ab98 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -16,6 +16,8 @@ go_library( ], visibility = ["//test/packetimpact:__subpackages__"], deps = [ + "//pkg/abi/linux", + "//pkg/binary", "//pkg/hostarch", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index eabdc8cb3..269e163bb 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -22,11 +22,13 @@ import ( "testing" "time" - pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" - "golang.org/x/sys/unix" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "gvisor.dev/gvisor/pkg/abi/linux" + bin "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" ) // DUT communicates with the DUT to force it to make POSIX calls. @@ -428,6 +430,33 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, so return ret, timeval, errno } +// GetSockOptTCPInfo retreives TCPInfo for the given socket descriptor. +func (dut *DUT) GetSockOptTCPInfo(t *testing.T, sockfd int32) linux.TCPInfo { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, info, err := dut.GetSockOptTCPInfoWithErrno(ctx, t, sockfd) + if ret != 0 || err != unix.Errno(0) { + t.Fatalf("failed to GetSockOptTCPInfo: %s", err) + } + return info +} + +// GetSockOptTCPInfoWithErrno retreives TCPInfo with any errno. 
+func (dut *DUT) GetSockOptTCPInfoWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, linux.TCPInfo, error) { + t.Helper() + + info := linux.TCPInfo{} + ret, infoBytes, errno := dut.GetSockOptWithErrno(ctx, t, sockfd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } + bin.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + + return ret, info, errno +} + // Listen calls listen on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // ListenWithErrno. diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index c4fe293e0..b1d280f98 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -104,8 +104,6 @@ packetimpact_testbench( srcs = ["tcp_retransmits_test.go"], deps = [ "//pkg/abi/linux", - "//pkg/binary", - "//pkg/hostarch", "//pkg/tcpip/header", "//test/packetimpact/testbench", "@org_golang_x_sys//unix:go_default_library", @@ -189,6 +187,7 @@ packetimpact_testbench( name = "tcp_synsent_reset", srcs = ["tcp_synsent_reset_test.go"], deps = [ + "//pkg/abi/linux", "//pkg/tcpip/header", "//test/packetimpact/testbench", "@org_golang_x_sys//unix:go_default_library", @@ -353,8 +352,6 @@ packetimpact_testbench( srcs = ["tcp_rack_test.go"], deps = [ "//pkg/abi/linux", - "//pkg/binary", - "//pkg/hostarch", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", "//test/packetimpact/testbench", @@ -367,8 +364,6 @@ packetimpact_testbench( srcs = ["tcp_info_test.go"], deps = [ "//pkg/abi/linux", - "//pkg/binary", - "//pkg/hostarch", "//pkg/tcpip/header", "//test/packetimpact/testbench", "@org_golang_x_sys//unix:go_default_library", diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go index 93f58ec49..b7514e846 100644 --- 
a/test/packetimpact/tests/tcp_info_test.go +++ b/test/packetimpact/tests/tcp_info_test.go @@ -21,8 +21,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -53,13 +51,10 @@ func TestTCPInfo(t *testing.T) { } conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) - info := linux.TCPInfo{} - infoBytes := dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + info := dut.GetSockOptTCPInfo(t, acceptFD) + if got, want := uint32(info.State), linux.TCP_ESTABLISHED; got != want { + t.Fatalf("got %d want %d", got, want) } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) - rtt := time.Duration(info.RTT) * time.Microsecond rttvar := time.Duration(info.RTTVar) * time.Microsecond rto := time.Duration(info.RTO) * time.Microsecond @@ -94,12 +89,7 @@ func TestTCPInfo(t *testing.T) { t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) } - info = linux.TCPInfo{} - infoBytes = dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) - } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + info = dut.GetSockOptTCPInfo(t, acceptFD) if info.CaState != linux.TCP_CA_Loss { t.Errorf("expected the connection to be in loss recovery, got: %v want: %v", info.CaState, linux.TCP_CA_Loss) } diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go index ff1431bbf..5a60bf712 100644 --- a/test/packetimpact/tests/tcp_rack_test.go +++ b/test/packetimpact/tests/tcp_rack_test.go @@ -21,8 +21,6 @@ import ( 
"golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/test/packetimpact/testbench" @@ -69,12 +67,7 @@ func closeSACKConnection(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4 } func getRTTAndRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rtt, rto time.Duration) { - info := linux.TCPInfo{} - infoBytes := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) - } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + info := dut.GetSockOptTCPInfo(t, acceptFd) return time.Duration(info.RTT) * time.Microsecond, time.Duration(info.RTO) * time.Microsecond } @@ -402,12 +395,7 @@ func TestRACKWithLostRetransmission(t *testing.T) { } // Check the congestion control state. 
- info := linux.TCPInfo{} - infoBytes := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) - } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + info := dut.GetSockOptTCPInfo(t, acceptFd) if info.CaState != linux.TCP_CA_Recovery { t.Fatalf("expected connection to be in fast recovery, want: %v got: %v", linux.TCP_CA_Recovery, info.CaState) } diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go index 1eafe20c3..d3fb789f4 100644 --- a/test/packetimpact/tests/tcp_retransmits_test.go +++ b/test/packetimpact/tests/tcp_retransmits_test.go @@ -21,9 +21,6 @@ import ( "time" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -33,12 +30,7 @@ func init() { } func getRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rto time.Duration) { - info := linux.TCPInfo{} - infoBytes := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("unexpected size for TCP_INFO, got %d bytes want %d bytes", got, want) - } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + info := dut.GetSockOptTCPInfo(t, acceptFd) return time.Duration(info.RTO) * time.Microsecond } diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go index cccb0abc6..fe53e7061 100644 --- a/test/packetimpact/tests/tcp_synsent_reset_test.go +++ b/test/packetimpact/tests/tcp_synsent_reset_test.go @@ -20,6 +20,7 @@ import ( "time" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/tcpip/header" 
"gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -29,7 +30,7 @@ func init() { } // dutSynSentState sets up the dut connection in SYN-SENT state. -func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16, uint16) { +func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, int32, uint16, uint16) { t.Helper() dut := testbench.NewDUT(t) @@ -46,26 +47,29 @@ func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16, t.Fatalf("expected SYN\n") } - return &dut, &conn, port, clientPort + return &dut, &conn, clientFD, port, clientPort } // TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition. func TestTCPSynSentReset(t *testing.T) { - _, conn, _, _ := dutSynSentState(t) + dut, conn, fd, _, _ := dutSynSentState(t) defer conn.Close(t) conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}) // Expect the connection to have closed. - // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } + info := dut.GetSockOptTCPInfo(t, fd) + if got, want := uint32(info.State), linux.TCP_CLOSE; got != want { + t.Fatalf("got %d want %d", got, want) + } } // TestTCPSynSentRcvdReset tests RFC793, p70, SYN-SENT to SYN-RCVD to CLOSED // transitions. 
func TestTCPSynSentRcvdReset(t *testing.T) { - dut, c, remotePort, clientPort := dutSynSentState(t) + dut, c, fd, remotePort, clientPort := dutSynSentState(t) defer c.Close(t) conn := dut.Net.NewTCPIPv4(t, testbench.TCP{SrcPort: &remotePort, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &remotePort}) @@ -79,9 +83,12 @@ func TestTCPSynSentRcvdReset(t *testing.T) { } conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. - // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } + info := dut.GetSockOptTCPInfo(t, fd) + if got, want := uint32(info.State), linux.TCP_CLOSE; got != want { + t.Fatalf("got %d want %d", got, want) + } } diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 729b4c63b..0582e16ce 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gtest", "select_arch", "select_system") +load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gbenchmark", "gtest", "select_arch", "select_system") package( default_visibility = ["//:sandbox"], @@ -520,13 +520,14 @@ cc_binary( srcs = ["concurrency.cc"], linkstatic = 1, deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", + gbenchmark, gtest, "//test/util:platform_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", ], ) @@ -1489,6 +1490,7 @@ cc_binary( "//test/util:cleanup", "//test/util:posix_error", "//test/util:pty_util", + "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", 
"//test/util:thread_util", @@ -1515,7 +1517,8 @@ cc_binary( cc_binary( name = "partial_bad_buffer_test", testonly = 1, - srcs = ["partial_bad_buffer.cc"], + # Android does not support preadv or pwritev in r22. + srcs = select_system(linux = ["partial_bad_buffer.cc"]), linkstatic = 1, deps = [ ":socket_test_util", @@ -1574,6 +1577,7 @@ cc_binary( "@com_google_absl//absl/time", gtest, "//test/util:posix_error", + "//test/util:signal_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -3671,7 +3675,8 @@ cc_binary( cc_binary( name = "sync_test", testonly = 1, - srcs = ["sync.cc"], + # Android does not support syncfs in r22. + srcs = select_system(linux = ["sync.cc"]), linkstatic = 1, deps = [ gtest, @@ -3784,10 +3789,9 @@ cc_binary( srcs = ["timers.cc"], linkstatic = 1, deps = [ - "//test/util:cleanup", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", + gbenchmark, gtest, + "//test/util:cleanup", "//test/util:logging", "//test/util:multiprocess_util", "//test/util:posix_error", @@ -3795,6 +3799,8 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/time", ], ) @@ -3966,7 +3972,8 @@ cc_binary( cc_binary( name = "utimes_test", testonly = 1, - srcs = ["utimes.cc"], + # Android does not support futimesat in r22. + srcs = select_system(linux = ["utimes.cc"]), linkstatic = 1, deps = [ "//test/util:file_descriptor", @@ -4082,7 +4089,8 @@ cc_binary( cc_binary( name = "semaphore_test", testonly = 1, - srcs = ["semaphore.cc"], + # Android does not support XSI semaphores in r22. 
+ srcs = select_system(linux = ["semaphore.cc"]), linkstatic = 1, deps = [ "//test/util:capability_util", diff --git a/test/syscalls/linux/cgroup.cc b/test/syscalls/linux/cgroup.cc index a009ade7e..f29891571 100644 --- a/test/syscalls/linux/cgroup.cc +++ b/test/syscalls/linux/cgroup.cc @@ -227,6 +227,41 @@ TEST(Cgroup, MountRace) { EXPECT_NO_ERRNO(c.ContainsCallingProcess()); } +TEST(Cgroup, MountUnmountRace) { + SKIP_IF(!CgroupsAvailable()); + + TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + const DisableSave ds; // Too many syscalls. + + auto mount_thread = [&mountpoint]() { + for (int i = 0; i < 100; ++i) { + mount("none", mountpoint.path().c_str(), "cgroup", 0, 0); + } + }; + auto unmount_thread = [&mountpoint]() { + for (int i = 0; i < 100; ++i) { + umount(mountpoint.path().c_str()); + } + }; + std::list<ScopedThread> threads; + for (int i = 0; i < 10; ++i) { + threads.emplace_back(mount_thread); + } + for (int i = 0; i < 10; ++i) { + threads.emplace_back(unmount_thread); + } + for (auto& t : threads) { + t.Join(); + } + + // We don't know how many mount refs are remaining, since the count depends on + // the ordering of mount and umount calls. Keep calling unmount until it + // returns an error. 
+ while (umount(mountpoint.path().c_str()) == 0) { + } +} + TEST(Cgroup, UnmountRepeated) { SKIP_IF(!CgroupsAvailable()); diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc index 7cd6a75bd..f2daf49ee 100644 --- a/test/syscalls/linux/concurrency.cc +++ b/test/syscalls/linux/concurrency.cc @@ -20,6 +20,7 @@ #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "benchmark/benchmark.h" #include "test/util/platform_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -106,6 +107,8 @@ TEST(ConcurrencyTest, MultiProcessConcurrency) { pid_t child_pid = fork(); if (child_pid == 0) { while (true) { + int x = 0; + benchmark::DoNotOptimize(x); // Don't optimize this loop away. } } ASSERT_THAT(child_pid, SyscallSucceeds()); diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index af3d27894..3ef8b0327 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -230,6 +230,8 @@ TEST(EpollTest, WaitThenUnblock) { EXPECT_THAT(pthread_detach(thread), SyscallSucceeds()); } +#ifndef ANDROID // Android does not support pthread_cancel + void sighandler(int s) {} void* signaler(void* arg) { @@ -272,6 +274,8 @@ TEST(EpollTest, UnblockWithSignal) { EXPECT_THAT(pthread_detach(thread), SyscallSucceeds()); } +#endif // ANDROID + TEST(EpollTest, TimeoutNoFds) { auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); struct epoll_event result[kFDsPerEpoll]; diff --git a/test/syscalls/linux/exec_state_workload.cc b/test/syscalls/linux/exec_state_workload.cc index 028902b14..eafdc2bfa 100644 --- a/test/syscalls/linux/exec_state_workload.cc +++ b/test/syscalls/linux/exec_state_workload.cc @@ -26,6 +26,8 @@ #include "absl/strings/numbers.h" +#ifndef ANDROID // Conflicts with existing operator<< on Android. + // Pretty-print a sigset_t. 
std::ostream& operator<<(std::ostream& out, const sigset_t& s) { out << "{ "; @@ -40,6 +42,8 @@ std::ostream& operator<<(std::ostream& out, const sigset_t& s) { return out; } +#endif + // Verify that the signo handler is handler. int CheckSigHandler(uint32_t signo, uintptr_t handler) { struct sigaction sa; diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index 98d07ae85..95082a0f2 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -174,13 +174,21 @@ SocketKind IPv6TCPUnboundSocket(int type) { PosixError IfAddrHelper::Load() { Release(); +#ifndef ANDROID RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_)); +#else + // Android does not support getifaddrs in r22. + return PosixError(ENOSYS, "getifaddrs"); +#endif return NoError(); } void IfAddrHelper::Release() { if (ifaddr_) { +#ifndef ANDROID + // Android does not support freeifaddrs in r22. freeifaddrs(ifaddr_); +#endif ifaddr_ = nullptr; } } diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc index 6ce1e6cc3..d4f89527c 100644 --- a/test/syscalls/linux/lseek.cc +++ b/test/syscalls/linux/lseek.cc @@ -150,7 +150,7 @@ TEST(LseekTest, SeekCurrentDir) { // From include/linux/fs.h. 
constexpr loff_t MAX_LFS_FILESIZE = 0x7fffffffffffffff; - char* dir = get_current_dir_name(); + char* dir = getcwd(NULL, 0); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir, O_RDONLY)); ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 96c454485..294a72468 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -14,6 +14,7 @@ #include <fcntl.h> /* Obtain O_* constant definitions */ #include <linux/magic.h> +#include <signal.h> #include <sys/ioctl.h> #include <sys/statfs.h> #include <sys/uio.h> @@ -29,6 +30,7 @@ #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/posix_error.h" +#include "test/util/signal_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -44,6 +46,28 @@ constexpr int kTestValue = 0x12345678; // Used for synchronization in race tests. const absl::Duration syncDelay = absl::Seconds(2); +std::atomic<int> global_num_signals_received = 0; +void SigRecordingHandler(int signum, siginfo_t* siginfo, + void* unused_ucontext) { + global_num_signals_received++; +} + +PosixErrorOr<Cleanup> RegisterSignalHandler(int signum) { + struct sigaction handler; + handler.sa_sigaction = SigRecordingHandler; + sigemptyset(&handler.sa_mask); + handler.sa_flags = SA_SIGINFO; + return ScopedSigaction(signum, handler); +} + +void WaitForSignalDelivery(absl::Duration timeout, int max_expected) { + absl::Time wait_start = absl::Now(); + while (global_num_signals_received < max_expected && + absl::Now() - wait_start < timeout) { + absl::SleepFor(absl::Milliseconds(10)); + } +} + struct PipeCreator { std::string name_; @@ -267,6 +291,9 @@ TEST_P(PipeTest, Seek) { } } +#ifndef ANDROID +// Android does not support preadv or pwritev in r22. 
+ TEST_P(PipeTest, OffsetCalls) { SKIP_IF(!CreateBlocking()); @@ -283,6 +310,8 @@ TEST_P(PipeTest, OffsetCalls) { EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); } +#endif // ANDROID + TEST_P(PipeTest, WriterSideCloses) { SKIP_IF(!CreateBlocking()); @@ -333,10 +362,16 @@ TEST_P(PipeTest, WriterSideClosesReadDataFirst) { TEST_P(PipeTest, ReaderSideCloses) { SKIP_IF(!CreateBlocking()); + const auto signal_cleanup = + ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGPIPE)); + ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); int buf = kTestValue; EXPECT_THAT(write(wfd_.get(), &buf, sizeof(buf)), SyscallFailsWithErrno(EPIPE)); + + WaitForSignalDelivery(absl::Seconds(1), 1); + ASSERT_EQ(global_num_signals_received, 1); } TEST_P(PipeTest, CloseTwice) { @@ -355,6 +390,9 @@ TEST_P(PipeTest, CloseTwice) { TEST_P(PipeTest, BlockWriteClosed) { SKIP_IF(!CreateBlocking()); + const auto signal_cleanup = + ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGPIPE)); + absl::Notification notify; ScopedThread t([this, &notify]() { std::vector<char> buf(Size()); @@ -371,6 +409,10 @@ TEST_P(PipeTest, BlockWriteClosed) { notify.WaitForNotification(); ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); + + WaitForSignalDelivery(absl::Seconds(1), 1); + ASSERT_EQ(global_num_signals_received, 1); + t.Join(); } @@ -379,6 +421,9 @@ TEST_P(PipeTest, BlockWriteClosed) { TEST_P(PipeTest, BlockPartialWriteClosed) { SKIP_IF(!CreateBlocking()); + const auto signal_cleanup = + ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGPIPE)); + ScopedThread t([this]() { const int pipe_size = Size(); std::vector<char> buf(2 * pipe_size); @@ -396,6 +441,10 @@ TEST_P(PipeTest, BlockPartialWriteClosed) { // Unblock the above.
ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); + + WaitForSignalDelivery(absl::Seconds(1), 2); + ASSERT_EQ(global_num_signals_received, 2); + t.Join(); } diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index 8d15c491e..5ff1f12a0 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -40,6 +40,7 @@ #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" #include "test/util/pty_util.h" +#include "test/util/signal_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -387,6 +388,22 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, } TEST(PtyTrunc, Truncate) { + SKIP_IF(IsRunningWithVFS1()); + + // setsid either puts us in a new session or fails because we're already the + // session leader. Either way, this ensures we're the session leader and have + // no controlling terminal. + ASSERT_THAT(setsid(), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EPERM))); + + // Make sure we're ignoring SIGHUP, which will be sent to this process once we + // disconnect the TTY. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + sa.sa_flags = 0; + sigemptyset(&sa.sa_mask); + const Cleanup cleanup = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGHUP, sa)); + // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to // (f)truncate should. FileDescriptor master = @@ -395,6 +412,7 @@ TEST(PtyTrunc, Truncate) { std::string spath = absl::StrCat("/dev/pts/", n); FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); + ASSERT_THAT(ioctl(replica.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); @@ -464,10 +482,10 @@ TEST(BasicPtyTest, OpenSetsControllingTTY) { SKIP_IF(IsRunningWithVFS1()); // setsid either puts us in a new session or fails because we're already the // session leader. 
Either way, this ensures we're the session leader. - setsid(); + ASSERT_THAT(setsid(), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EPERM))); // Make sure we're ignoring SIGHUP, which will be sent to this process once we - // disconnect they TTY. + // disconnect the TTY. struct sigaction sa = {}; sa.sa_handler = SIG_IGN; sa.sa_flags = 0; @@ -491,7 +509,7 @@ TEST(BasicPtyTest, OpenMasterDoesNotSetsControllingTTY) { SKIP_IF(IsRunningWithVFS1()); // setsid either puts us in a new session or fails because we're already the // session leader. Either way, this ensures we're the session leader. - setsid(); + ASSERT_THAT(setsid(), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EPERM))); FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); // Opening master does not set the controlling TTY, and therefore we are @@ -503,7 +521,7 @@ TEST(BasicPtyTest, OpenNOCTTY) { SKIP_IF(IsRunningWithVFS1()); // setsid either puts us in a new session or fails because we're already the // session leader. Either way, this ensures we're the session leader. - setsid(); + ASSERT_THAT(setsid(), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EPERM))); FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE( OpenReplica(master, O_NOCTTY | O_NONBLOCK | O_RDWR)); @@ -1405,7 +1423,7 @@ TEST_F(JobControlTest, ReleaseTTY) { ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Make sure we're ignoring SIGHUP, which will be sent to this process once we - // disconnect they TTY. + // disconnect the TTY. struct sigaction sa = {}; sa.sa_handler = SIG_IGN; sa.sa_flags = 0; @@ -1526,7 +1544,7 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { EXPECT_THAT(setpgid(diff_pgrp_child, diff_pgrp_child), SyscallSucceeds()); // Make sure we're ignoring SIGHUP, which will be sent to this process once we - // disconnect they TTY. + // disconnect the TTY. 
struct sigaction sighup_sa = {}; sighup_sa.sa_handler = SIG_IGN; sighup_sa.sa_flags = 0; diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc index 7056342d7..7756af24d 100644 --- a/test/syscalls/linux/read.cc +++ b/test/syscalls/linux/read.cc @@ -157,7 +157,8 @@ TEST_F(ReadTest, PartialReadSIGSEGV) { .iov_len = size, }, }; - EXPECT_THAT(preadv(fd.get(), iov, ABSL_ARRAYSIZE(iov), 0), + EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT(readv(fd.get(), iov, ABSL_ARRAYSIZE(iov)), SyscallSucceedsWithValue(size)); } diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc index f4ee775bd..ce5f63938 100644 --- a/test/syscalls/linux/socket_bind_to_device_util.cc +++ b/test/syscalls/linux/socket_bind_to_device_util.cc @@ -58,8 +58,10 @@ PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) { } std::unordered_set<string> GetInterfaceNames() { - struct if_nameindex* interfaces = if_nameindex(); std::unordered_set<string> names; +#ifndef ANDROID + // Android does not support if_nameindex in r22. + struct if_nameindex* interfaces = if_nameindex(); if (interfaces == nullptr) { return names; } @@ -68,6 +70,7 @@ std::unordered_set<string> GetInterfaceNames() { names.insert(interface->if_name); } if_freenameindex(interfaces); +#endif return names; } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 9a6b089f6..f99d6f1c7 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -472,6 +472,77 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } } +// Test the protocol state information returned by TCPINFO. +TEST_P(SocketInetLoopbackTest, TCPInfoState) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. 
+ FileDescriptor const listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + + auto state = [](int fd) -> int { + struct tcp_info opt = {}; + socklen_t optLen = sizeof(opt); + EXPECT_THAT(getsockopt(fd, SOL_TCP, TCP_INFO, &opt, &optLen), + SyscallSucceeds()); + return opt.tcpi_state; + }; + ASSERT_EQ(state(listen_fd.get()), TCP_CLOSE); + + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_EQ(state(listen_fd.get()), TCP_CLOSE); + + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + ASSERT_EQ(state(listen_fd.get()), TCP_LISTEN); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_EQ(state(conn_fd.get()), TCP_CLOSE); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + ASSERT_EQ(state(conn_fd.get()), TCP_ESTABLISHED); + + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + ASSERT_EQ(state(accepted.get()), TCP_ESTABLISHED); + + ASSERT_THAT(close(accepted.release()), SyscallSucceeds()); + + struct pollfd pfd = { + .fd = conn_fd.get(), + .events = POLLIN | POLLRDHUP, + }; + constexpr int kTimeout = 10000; + int n = poll(&pfd, 1, kTimeout); + ASSERT_GE(n, 0) << strerror(errno); + ASSERT_EQ(n, 1); + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/6015): Notify POLLRDHUP on incoming FIN. 
+ ASSERT_EQ(pfd.revents, POLLIN); + } else { + ASSERT_EQ(pfd.revents, POLLIN | POLLRDHUP); + } + + ASSERT_THAT(state(conn_fd.get()), TCP_CLOSE_WAIT); + ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds()); +} + void TestHangupDuringConnect(const TestParam& param, void (*hangup)(FileDescriptor&)) { TestAddress const& listener = param.listener; diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc index 8390f7c3b..09f070797 100644 --- a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc @@ -38,8 +38,8 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) { ipv6_mreq group_req = { .ipv6mr_multiaddr = reinterpret_cast<sockaddr_in6*>(&multicast_addr.addr)->sin6_addr, - .ipv6mr_interface = - (unsigned int)ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")), + .ipv6mr_interface = static_cast<decltype(ipv6_mreq::ipv6mr_interface)>( + ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"))), }; ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, &group_req, sizeof(group_req)), diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index 93a98adb1..bc12dd4af 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -26,6 +26,7 @@ #include "absl/flags/flag.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "benchmark/benchmark.h" #include "test/util/cleanup.h" #include "test/util/logging.h" #include "test/util/multiprocess_util.h" @@ -92,6 +93,8 @@ TEST(TimerTest, ProcessKilledOnCPUSoftLimit) { TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); MaybeSave(); for (;;) { + int x = 0; + benchmark::DoNotOptimize(x); // Don't optimize this loop away. 
} } ASSERT_THAT(pid, SyscallSucceeds()); @@ -151,6 +154,8 @@ TEST(TimerTest, ProcessPingedRepeatedlyAfterCPUSoftLimit) { TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); MaybeSave(); for (;;) { + int x = 0; + benchmark::DoNotOptimize(x); // Don't optimize this loop away. } } ASSERT_THAT(pid, SyscallSucceeds()); @@ -197,6 +202,8 @@ TEST(TimerTest, ProcessKilledOnCPUHardLimit) { TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); MaybeSave(); for (;;) { + int x = 0; + benchmark::DoNotOptimize(x); // Don't optimize this loop away. } } ASSERT_THAT(pid, SyscallSucceeds()); diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index 8eeabbc3d..c788654a8 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -709,8 +709,13 @@ func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) { return } - // Recursively collect information from - // the other analyzers. + // If this package is the atomic package, the implementation + // may be replaced by intrinsics that don't have analysis. + if x.Pkg.Pkg.Path() == "sync/atomic" { + return + } + + // Recursively collect information. var imp packageEscapeFacts if !pass.ImportPackageFact(x.Pkg.Pkg, &imp) { // Unable to import the dependency; we must diff --git a/website/BUILD b/website/BUILD index 6f52e9208..1a38967e5 100644 --- a/website/BUILD +++ b/website/BUILD @@ -165,6 +165,7 @@ docs( "//g3doc/user_guide/tutorials:cni", "//g3doc/user_guide/tutorials:docker", "//g3doc/user_guide/tutorials:docker_compose", + "//g3doc/user_guide/tutorials:knative", "//g3doc/user_guide/tutorials:kubernetes", ], ) |